Example #1
    def train(self):
        # Initialize the progress bar in the UI.
        training_progress = sly.Progress('Model training: ', self._epochs * self._train_iters)

        # Initialize the optimizer.
        optimizer = torch.optim.Adam(self._model.parameters(), lr=self.config[LR])
        # Running best loss value to determine which snapshot is the best so far.
        best_val_loss = float('inf')

        for epoch in range(self._epochs):
            sly.logger.info("Starting new epoch", extra={'epoch': self.epoch_flt})
            for train_it, (inputs_cpu, targets_cpu) in enumerate(self._data_loaders[TRAIN]):
                # Switch the model into training mode to enable gradient backpropagation and batch norm running average
                # updates.
                self._model.train()

                # Copy input batch to the GPU, run inference and compute optimization loss.
                inputs_cuda, targets_cuda = inputs_cpu.cuda(), targets_cpu.cuda()
                outputs_cuda = self._model(inputs_cuda)
                loss = self._loss_fn(outputs_cuda, targets_cuda)

                # Make a gradient descent step.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Advance the UI progress bar.
                training_progress.iter_done_report()
                # Compute fractional epoch value for more precise metrics reporting.
                self.epoch_flt = epoch_float(epoch, train_it + 1, self._train_iters)
                # Report metrics to be plotted in the training chart.
                sly.report_metrics_training(self.epoch_flt, {LOSS: loss.item()})

                # If needed, do validation and snapshotting.
                if self._eval_planner.need_validation(self.epoch_flt):
                    # Compute metrics on the validation dataset.
                    metrics_values_val = self._validation()

                    # Report progress.
                    self._eval_planner.validation_performed()

                    # Check whether the new weights are the best so far on the validation dataset.
                    val_loss = metrics_values_val[LOSS]
                    model_is_best = val_loss < best_val_loss
                    if model_is_best:
                        best_val_loss = val_loss

                    # Save a snapshot with the current weights. Mark whether the snapshot is the best so far in terms of
                    # validation loss.
                    self._save_model_snapshot(model_is_best, opt_data={
                        'epoch': self.epoch_flt,
                        'val_metrics': metrics_values_val,
                    })

            # Report progress
            sly.logger.info("Epoch has finished", extra={'epoch': self.epoch_flt})
Example #2
    def train(self):
        progress = sly.Progress('Model training: ', self.epochs * self.train_iters)
        self.model.train()

        lr_decr = self.config['lr_decreasing']
        policy = LRPolicyWithPatience(
            optim_cls=Adam,
            init_lr=self.config['lr'],
            patience=lr_decr['patience'],
            lr_divisor=lr_decr['lr_divisor'],
            model=self.model
        )
        best_val_loss = float('inf')

        for epoch in range(self.epochs):
            logger.info("Before new epoch", extra={'epoch': self.epoch_flt})

            for train_it, (inputs_cpu, targets_cpu) in enumerate(self.data_loaders['train']):
                # Move the batch to the GPU; the model parameters already require
                # gradients, and integer targets must not.
                inputs, targets = inputs_cpu.cuda(), targets_cpu.cuda()
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                policy.optimizer.zero_grad()
                loss.backward()
                policy.optimizer.step()

                metric_values_train = {'loss': loss.item()}
                for name, metric in self.metrics.items():
                    metric_values_train[name] = metric(outputs, targets)

                progress.iter_done_report()

                self.epoch_flt = epoch_float(epoch, train_it + 1, self.train_iters)
                sly.report_metrics_training(self.epoch_flt, metric_values_train)

                if self.eval_planner.need_validation(self.epoch_flt):
                    metrics_values_val = self._validation()
                    self.eval_planner.validation_performed()

                    val_loss = metrics_values_val['loss']
                    model_is_best = val_loss < best_val_loss
                    if model_is_best:
                        best_val_loss = val_loss
                        logger.info('Validation loss improved; current model is the best so far.')

                    self._save_model_snapshot(model_is_best, opt_data={
                        'epoch': self.epoch_flt,
                        'val_metrics': metrics_values_val,
                    })

                    policy.reset_if_needed(val_loss, self.model)

            logger.info("Epoch was finished", extra={'epoch': self.epoch_flt})
Example #3
    def train(self):
        progress = sly.Progress('Model training: ',
                                self.epochs * self.train_iters)
        self.model.train()

        lr_decr = self.config['lr_decreasing']
        policy = LRPolicyWithPatience(optim_cls=Adam,
                                      init_lr=self.config['lr'],
                                      patience=lr_decr['patience'],
                                      lr_divisor=lr_decr['lr_divisor'],
                                      model=self.model)
        best_val_loss = float('inf')

        debug_saver = None
        debug_save_prob = float(os.getenv('DEBUG_PATCHES_PROB', 0.0))
        if debug_save_prob > 0:
            target_multi = int(255.0 / len(self.out_classes))
            debug_saver = DebugSaver(odir=os.path.join(sly.TaskPaths.DEBUG_DIR,
                                                       'debug_patches'),
                                     prob=debug_save_prob,
                                     target_multi=target_multi)

        for epoch in range(self.epochs):
            sly.logger.info("Before new epoch",
                            extra={'epoch': self.epoch_flt})

            for train_it, (inputs_cpu, targets_cpu) in enumerate(self.data_loaders['train']):
                inputs, targets = cuda_variable(inputs_cpu), cuda_variable(targets_cpu)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                if debug_saver is not None:
                    out_cls = functional.softmax(outputs, dim=1)
                    debug_saver.process(inputs_cpu, targets_cpu,
                                        out_cls.data.cpu())

                policy.optimizer.zero_grad()
                loss.backward()
                policy.optimizer.step()

                metric_values_train = {'loss': loss.item()}
                for name, metric in self.metrics.items():
                    metric_values_train[name] = metric(outputs, targets)

                progress.iter_done_report()

                self.epoch_flt = epoch_float(epoch, train_it + 1,
                                             self.train_iters)
                sly.report_metrics_training(self.epoch_flt,
                                            metric_values_train)

                if self.eval_planner.need_validation(self.epoch_flt):
                    metrics_values_val = self._validation()
                    self.eval_planner.validation_performed()

                    val_loss = metrics_values_val['loss']
                    model_is_best = val_loss < best_val_loss
                    if model_is_best:
                        best_val_loss = val_loss
                        sly.logger.info('Validation loss improved; current model is the best so far.')

                    self._save_model_snapshot(model_is_best, opt_data={
                        'epoch': self.epoch_flt,
                        'val_metrics': metrics_values_val,
                    })

                    policy.reset_if_needed(val_loss, self.model)

            sly.logger.info("Epoch was finished",
                            extra={'epoch': self.epoch_flt})
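
Example #3 additionally dumps a random fraction of training patches via DebugSaver for visual inspection; the class itself is not shown. A hypothetical sketch, assumed from the constructor arguments (odir, prob, target_multi) and the process(inputs, targets, softmax_outputs) call:

import os
import random

import cv2
import numpy as np

class DebugSaver:
    def __init__(self, odir, prob, target_multi):
        self.odir = odir
        self.prob = prob  # fraction of patches to dump
        self.target_multi = target_multi  # scales class indices into [0, 255]
        self._cnt = 0
        os.makedirs(odir, exist_ok=True)

    def process(self, inputs_cpu, targets_cpu, outputs_cpu):
        # inputs_cpu could be dumped as well; omitted in this sketch.
        for tgt, out in zip(targets_cpu, outputs_cpu):
            if random.random() > self.prob:
                continue
            # Turn class-index maps into grayscale images for quick eyeballing.
            tgt_img = tgt.numpy().astype(np.uint8) * self.target_multi
            pred_img = out.argmax(0).numpy().astype(np.uint8) * self.target_multi
            cv2.imwrite(os.path.join(self.odir, '%06d_target.png' % self._cnt), tgt_img)
            cv2.imwrite(os.path.join(self.odir, '%06d_pred.png' % self._cnt), pred_img)
            self._cnt += 1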
Example #4
def train(datasets_dicts,
          epochs,
          val_every,
          iters_cnt,
          validate_with_eval_model,
          pipeline_config,
          num_clones=1,
          save_cback=None,
          is_transfer_learning=False):
    logger.info('Start train')
    configs = configs_from_pipeline(pipeline_config)

    model_config = configs['model']
    train_config = configs['train_config']

    create_model_fn = functools.partial(model_builder.build,
                                        model_config=model_config,
                                        is_training=True)
    detection_model = create_model_fn()

    def get_next(dataset):
        return dataset_util.make_initializable_iterator(
            build_dataset(dataset)).get_next()

    create_tensor_dict_fn = functools.partial(get_next,
                                              datasets_dicts['train'])
    create_tensor_dict_fn_val = functools.partial(get_next,
                                                  datasets_dicts['val'])

    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=False,
            replica_id=0,
            num_replicas=1,
            num_ps_tasks=0,
            worker_job_name='lonely_worker')

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            coord = coordinator.Coordinator()
            input_queue = create_input_queue(
                train_config.batch_size, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

            input_queue_val = create_input_queue(
                train_config.batch_size, create_tensor_dict_fn_val,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # Create the validation graph.
        create_model_fn_val = functools.partial(
            model_builder.build,
            model_config=model_config,
            is_training=not validate_with_eval_model)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var, family='LearningRate')

        train_losses = []
        grads_and_vars = []
        with slim.arg_scope([slim.model_variable, slim.variable],
                            device='/device:CPU:0'):
            for curr_dev_id in range(num_clones):
                with tf.device('/gpu:{}'.format(curr_dev_id)):
                    with tf.name_scope(
                            'clone_{}'.format(curr_dev_id)) as scope:
                        with tf.variable_scope(
                                tf.get_variable_scope(),
                                reuse=True if curr_dev_id > 0 else None):
                            losses = _create_losses_val(
                                input_queue, create_model_fn, train_config)
                            clones_loss = tf.add_n(losses)
                            clones_loss = tf.divide(clones_loss,
                                                    1.0 * num_clones)
                            grads = training_optimizer.compute_gradients(
                                clones_loss)
                            train_losses.append(clones_loss)
                            grads_and_vars.append(grads)
                            if curr_dev_id == 0:
                                update_ops = tf.get_collection(
                                    tf.GraphKeys.UPDATE_OPS)

        val_total_loss = get_val_loss(num_clones, input_queue_val,
                                      create_model_fn_val, train_config)

        with tf.device(deploy_config.optimizer_device()):
            total_loss = tf.add_n(train_losses)
            grads_and_vars = model_deploy._sum_clones_gradients(grads_and_vars)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        coord.clear_stop()
        sess = tf.Session(config=config)
        saver = tf.train.Saver()

        graph = ops.get_default_graph()
        with graph.as_default():
            with ops.name_scope('init_ops'):
                init_op = variables.global_variables_initializer()
                ready_op = variables.report_uninitialized_variables()
                local_init_op = control_flow_ops.group(
                    variables.local_variables_initializer(),
                    lookup_ops.tables_initializer())

        # graph.finalize()
        sess.run([init_op, ready_op, local_init_op])

        queue_runners = graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
        threads = []
        for qr in queue_runners:
            threads.extend(
                qr.create_threads(sess, coord=coord, daemon=True, start=True))

        logger.info('Start restore')
        if train_config.fine_tune_checkpoint:
            var_map = detection_model.restore_map(
                fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
                load_all_detection_checkpoint_vars=(
                    train_config.load_all_detection_checkpoint_vars
                    and (not is_transfer_learning)))
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, train_config.fine_tune_checkpoint))
            if 'global_step' in available_var_map:
                del available_var_map['global_step']
            init_saver = tf.train.Saver(available_var_map)
            logger.info('Restoring model weights from previous checkpoint.')
            init_saver.restore(sess, train_config.fine_tune_checkpoint)
            logger.info('Model restored.')

        eval_planner = EvalPlanner(epochs, val_every)
        progress = sly.Progress('Model training: ',
                                epochs * iters_cnt['train'])
        best_val_loss = float('inf')
        epoch_flt = 0

        for epoch in range(epochs):
            logger.info("Before new epoch", extra={'epoch': epoch_flt})
            for train_it in range(iters_cnt['train']):
                total_loss, np_global_step = sess.run(
                    [train_tensor, global_step])

                metrics_values_train = {
                    'loss': total_loss,
                }

                progress.iter_done_report()
                epoch_flt = epoch_float(epoch, train_it + 1,
                                        iters_cnt['train'])
                sly.report_metrics_training(epoch_flt, metrics_values_train)

                if eval_planner.need_validation(epoch_flt):
                    logger.info("Before validation",
                                extra={'epoch': epoch_flt})

                    overall_val_loss = 0
                    for val_it in range(iters_cnt['val']):
                        overall_val_loss += sess.run(val_total_loss)

                        logger.info("Validation in progress",
                                    extra={
                                        'epoch': epoch_flt,
                                        'val_iter': val_it,
                                        'val_iters': iters_cnt['val']
                                    })

                    metrics_values_val = {
                        'loss': overall_val_loss / iters_cnt['val'],
                    }
                    sly.report_metrics_validation(epoch_flt,
                                                  metrics_values_val)
                    logger.info("Validation has been finished",
                                extra={'epoch': epoch_flt})

                    eval_planner.validation_performed()

                    val_loss = metrics_values_val['loss']
                    model_is_best = val_loss < best_val_loss
                    if model_is_best:
                        best_val_loss = val_loss
                        logger.info('Validation loss improved; current model is the best so far.')

                    save_cback(saver,
                               sess,
                               model_is_best,
                               opt_data={
                                   'epoch': epoch_flt,
                                   'val_metrics': metrics_values_val,
                               })

            logger.info("Epoch was finished", extra={'epoch': epoch_flt})
        coord.request_stop()
        coord.join(threads)
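
Examples #4 and #5 schedule validation through EvalPlanner(epochs, val_every). A hypothetical sketch consistent with the need_validation / validation_performed protocol used above (assumed, not the SDK source):

class EvalPlanner:
    def __init__(self, epochs, val_every):
        self.epochs = epochs
        self.val_every = val_every  # validation period, in (fractional) epochs
        self._performed = 0

    def need_validation(self, epoch_flt):
        # True once the fractional epoch crosses the next validation point.
        return epoch_flt >= (self._performed + 1) * self.val_every

    def validation_performed(self):
        self._performed += 1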
Example #5
    def train(self):
        self.device_ids = sly.env.remap_gpu_devices(self.config['gpu_devices'])

        epochs = self.config['epochs']
        session_config = tf.ConfigProto(allow_soft_placement=True)
        self.session = tf.Session(config=session_config)
        opt = tf.train.AdamOptimizer(self.config['lr'])
        train_op = opt.minimize(self.loss)
        init_op = tf.global_variables_initializer()
        self.session.run(init_op)

        if sly.fs.dir_empty(sly.TaskPaths.MODEL_DIR):  # @TODO: implement transfer learning
            sly.logger.info('Weights were initialized randomly.')
        elif self.config['weights_init_type'] == TRANSFER_LEARNING:
            vars_to_restore = slim.get_variables_to_restore()
            variables_to_restore = [
                v for v in vars_to_restore
                if ('Adam' not in v.name and '_power' not in v.name
                    and 'logits' not in v.name)
            ]
            re_saver = tf.train.Saver(variables_to_restore, max_to_keep=0)
            re_saver.restore(
                self.session,
                os.path.join(sly.TaskPaths.MODEL_DIR, 'model_weights',
                             'model.ckpt'))
        elif self.config['weights_init_type'] == CONTINUE_TRAINING:
            re_saver = tf.train.Saver(max_to_keep=0)
            re_saver.restore(
                self.session,
                os.path.join(sly.TaskPaths.MODEL_DIR, 'model_weights',
                             'model.ckpt'))
            sly.logger.info('Restored model weights from training')

        eval_planner = EvalPlanner(epochs, self.config['val_every'])
        progress = sly.Progress('Model training: ', epochs * self.train_iters)
        best_val_loss = float('inf')
        self.saver = tf.train.Saver(max_to_keep=0)

        for epoch in range(epochs):
            sly.logger.info("Before new epoch",
                            extra={'epoch': self.epoch_flt})
            for train_it, (batch_inputs, batch_targets) in enumerate(
                    self.data_loaders['train']):
                feed = {self.inputs: batch_inputs, self.labels: batch_targets}

                tl, _ = self.session.run([self.loss, train_op], feed)

                metrics_values_train = {
                    'loss': tl,
                }

                progress.iter_done_report()
                self.epoch_flt = epoch_float(epoch, train_it + 1,
                                             len(self.data_loaders['train']))
                sly.report_metrics_training(self.epoch_flt,
                                            metrics_values_train)

                if eval_planner.need_validation(self.epoch_flt):
                    sly.logger.info("Before validation",
                                    extra={'epoch': self.epoch_flt})

                    val_metrics_values = self._validation(self.session)
                    eval_planner.validation_performed()

                    val_loss = val_metrics_values['loss']

                    model_is_best = val_loss < best_val_loss
                    if model_is_best:
                        best_val_loss = val_loss
                        sly.logger.info('Validation loss improved; current model is the best so far.')

                    self._save_model_snapshot(model_is_best, opt_data={
                        'epoch': self.epoch_flt,
                        'val_metrics': val_metrics_values,
                    })
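
Example #5 (like #1-#3) delegates checkpointing to self._save_model_snapshot, which is not shown. A hypothetical sketch for the TensorFlow case, reusing the tf.train.Saver created above; the output path and metadata handling are assumptions:

    def _save_model_snapshot(self, model_is_best, opt_data):
        # Write the current weights; opt_data carries the fractional epoch and
        # validation metrics so the platform can label the checkpoint.
        out_dir = os.path.join(sly.TaskPaths.RESULTS_DIR, 'model_weights')  # assumed path
        os.makedirs(out_dir, exist_ok=True)
        self.saver.save(self.session, os.path.join(out_dir, 'model.ckpt'))
        if model_is_best:
            sly.logger.info('Snapshot marked as the best so far.', extra=opt_data)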