Пример #1
0
    def train(self):
        """Run the queue-fed TensorFlow training loop with periodic validation.

        Starts the graph's queue runners and runs ``self.epochs`` epochs of
        ``self.iters_cnt['train']`` optimizer steps each.

        Per step it reports training loss/accuracy through
        ``sly.report_metrics_training``. Whenever ``self.eval_planner`` asks for
        validation, it averages loss/accuracy over ``self.iters_cnt['val']``
        validation batches and reports them. It then dumps a checkpoint via
        ``self._dump_model``, flagged as best when the validation loss is the
        lowest seen so far.

        The coordinator is always asked to stop and the queue-runner threads
        are joined, even when training raises.
        """
        # Start queue threads.
        threads = tf.train.start_queue_runners(coord=self.coord, sess=self.sess)
        try:
            progress = sly.progress_counter_train(self.epochs, self.iters_cnt['train'])
            best_val_loss = float('inf')
            internal_step = 0

            for epoch in range(self.epochs):
                logger.info("Before new epoch", extra={'epoch': self.epoch_flt})

                for train_it in range(self.iters_cnt['train']):
                    feed_dict = {self.step_ph: internal_step}
                    internal_step += 1
                    loss_value, _ = self.sess.run([self.total_loss, self.train_op], feed_dict=feed_dict)

                    # Reset the streaming metric's running variables, then update it
                    # with this batch's labels/predictions. Both are fetched in ONE
                    # run call so they come from a single evaluation of the graph.
                    # If v1/v2 read from an input queue (presumably - verify), two
                    # separate run calls would dequeue two different batches and
                    # pair labels with unrelated predictions.
                    self.sess.run(self.running_vars_initializer)
                    plh_label, plh_prediction = self.sess.run([self.v1, self.v2])
                    feed_dict = {self.tf_label: plh_label, self.tf_prediction: plh_prediction}
                    train_accuracy = self.sess.run(self.tf_metric_update, feed_dict=feed_dict)  # from last GPU

                    metrics_values_train = {
                        'loss': loss_value,
                        'accuracy': train_accuracy
                    }

                    progress.iter_done_report()
                    self.epoch_flt = epoch_float(epoch, train_it + 1, self.iters_cnt['train'])
                    sly.report_metrics_training(self.epoch_flt, metrics_values_train)

                    if self.eval_planner.need_validation(self.epoch_flt):
                        logger.info("Before validation", extra={'epoch': self.epoch_flt})

                        overall_val_loss = 0
                        overall_val_accuracy = 0
                        for val_it in range(self.iters_cnt['val']):
                            overall_val_loss += self.sess.run(self.total_val_loss)
                            # Same reset-then-update pattern as in training; labels and
                            # predictions are fetched together for the same reason.
                            self.sess.run(self.running_vars_initializer)
                            plh_label, plh_prediction = self.sess.run([self.v1_val, self.v2_val])
                            feed_dict = {self.tf_label: plh_label, self.tf_prediction: plh_prediction}
                            overall_val_accuracy += self.sess.run(self.tf_metric_update, feed_dict=feed_dict)

                            logger.info("Validation in progress", extra={'epoch': self.epoch_flt,
                                                                         'val_iter': val_it,
                                                                         'val_iters': self.iters_cnt['val']})

                        metrics_values_val = {
                            'loss': overall_val_loss / self.iters_cnt['val'],
                            'accuracy': overall_val_accuracy / self.iters_cnt['val']
                        }
                        sly.report_metrics_validation(self.epoch_flt, metrics_values_val)
                        logger.info("Validation has been finished", extra={'epoch': self.epoch_flt})

                        self.eval_planner.validation_performed()

                        val_loss = metrics_values_val['loss']
                        model_is_best = val_loss < best_val_loss
                        if model_is_best:
                            best_val_loss = val_loss
                            logger.info('It\'s been determined that current model is the best one for a while.')

                        self._dump_model(model_is_best, opt_data={
                            'epoch': self.epoch_flt,
                            'val_metrics': metrics_values_val,
                        })

                logger.info("Epoch was finished", extra={'epoch': self.epoch_flt})
        finally:
            # Shut the queue runners down even if training failed part-way.
            self.coord.request_stop()
            self.coord.join(threads)
Пример #2
0
def train(data_dicts,
          class_num,
          input_size,
          lr,
          n_epochs,
          num_clones,
          iters_cnt,
          val_every,
          model_init_fn,
          save_cback,
          atrous_rates=None,
          fine_tune_batch_norm=True,
          output_stride=16):
    """Build a multi-clone DeepLab training graph and train it with periodic validation.

    Args:
        data_dicts: mapping with 'train' and 'val' entries. Each is passed to
            ``get`` to build the input samples.
        class_num: number of semantic classes, used as the 'semantic' output size.
        input_size: forwarded to ``get`` and to both model builders.
        lr: learning rate of the Adam optimizer.
        n_epochs: number of training epochs.
        num_clones: number of model clones created by ``model_deploy``.
        iters_cnt: mapping with 'train' (steps per epoch) and 'val' (batches
            per validation pass).
        val_every: validation frequency, passed to ``EvalPlanner``.
        model_init_fn: called with the session right after variable
            initialization (presumably to load initial weights).
        save_cback: called as
            ``save_cback(saver, sess, model_is_best, opt_data=...)`` after
            every validation pass.
        atrous_rates: atrous rates forwarded to the model builders. ``None``
            (the default) means ``[6, 12, 18]``. A fresh list is built per
            call rather than sharing one mutable default list.
        fine_tune_batch_norm: forwarded to the training model builder.
        output_stride: forwarded to both model builders.

    The queue-runner threads are stopped and joined, and the session is
    closed, when training finishes or raises.
    """
    if atrous_rates is None:
        atrous_rates = [6, 12, 18]

    tf.logging.set_verbosity(tf.logging.INFO)

    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    # clone_on_cpu, task, num_replicas and num_ps_tasks are not parameters of
    # this function; presumably they are module-level settings (not visible here).
    deploy_config = model_deploy.DeploymentConfig(num_clones=num_clones,
                                                  clone_on_cpu=clone_on_cpu,
                                                  replica_id=task,
                                                  num_replicas=num_replicas,
                                                  num_ps_tasks=num_ps_tasks)

    with tf.Graph().as_default():
        with tf.device(deploy_config.inputs_device()):
            samples = get(data_dicts['train'],
                          input_size,
                          is_training=True,
                          model_variant=model_variant)
            # NOTE(review): validation samples are also built with
            # is_training=True (training-mode input processing) - confirm intended.
            samples_val = get(data_dicts['val'],
                              input_size,
                              is_training=True,
                              model_variant=model_variant)

        inputs_queue = prefetch_queue.prefetch_queue(samples,
                                                     capacity=128 *
                                                     deploy_config.num_clones,
                                                     dynamic_pad=True)
        inputs_queue_val = prefetch_queue.prefetch_queue(samples_val,
                                                         capacity=128 *
                                                         deploy_config.num_clones,
                                                         dynamic_pad=True)
        coord = tf.train.Coordinator()

        # Create the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (inputs_queue, {
                'semantic': class_num
            }, input_size, atrous_rates, output_stride, fine_tune_batch_norm)
            clones = model_deploy.create_clones(deploy_config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = deploy_config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Build the optimizer based on the device specification.
        with tf.device(deploy_config.optimizer_device()):
            optimizer = tf.train.AdamOptimizer(lr)

        with tf.device(deploy_config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')

            # Validation graph: separate clones fed from the validation queue.
            model_fn_val = _build_deeplab_val
            model_args_val = (inputs_queue_val, {
                'semantic': class_num
            }, input_size, atrous_rates, output_stride)
            val_clones, val_losses = create_val_clones(num_clones,
                                                       deploy_config,
                                                       model_fn_val,
                                                       args=model_args_val)
            val_total_loss = get_clones_val_losses(val_clones, None,
                                                   val_losses)

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes()
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op. Running train_tensor applies the
            # gradients and the first clone's update ops before returning the loss.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        sess_config = tf.ConfigProto(allow_soft_placement=True,
                                     log_device_placement=False)
        coord.clear_stop()
        sess = tf.Session(config=sess_config)

        graph = ops.get_default_graph()
        with graph.as_default():
            with ops.name_scope('init_ops'):
                init_op = variables.global_variables_initializer()
                ready_op = variables.report_uninitialized_variables()
                local_init_op = control_flow_ops.group(
                    variables.local_variables_initializer(),
                    lookup_ops.tables_initializer())

        sess.run([init_op, ready_op, local_init_op])
        queue_runners = graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
        threads = []
        for qr in queue_runners:
            threads.extend(
                qr.create_threads(sess, coord=coord, daemon=True, start=True))

        try:
            model_init_fn(sess)
            saver = tf.train.Saver()
            eval_planner = EvalPlanner(n_epochs, val_every)
            progress = sly.progress_counter_train(n_epochs, iters_cnt['train'])
            best_val_loss = float('inf')
            epoch_flt = 0

            for epoch in range(n_epochs):
                logger.info("Before new epoch", extra={'epoch': epoch_flt})
                for train_it in range(iters_cnt['train']):
                    # Do not rebind total_loss (the graph tensor) to a numpy value.
                    loss_value = sess.run(train_tensor)

                    metrics_values_train = {
                        'loss': loss_value,
                    }

                    progress.iter_done_report()
                    epoch_flt = epoch_float(epoch, train_it + 1,
                                            iters_cnt['train'])
                    sly.report_metrics_training(epoch_flt, metrics_values_train)

                    if eval_planner.need_validation(epoch_flt):
                        logger.info("Before validation",
                                    extra={'epoch': epoch_flt})

                        overall_val_loss = 0
                        for val_it in range(iters_cnt['val']):
                            overall_val_loss += sess.run(val_total_loss)

                            logger.info("Validation in progress",
                                        extra={
                                            'epoch': epoch_flt,
                                            'val_iter': val_it,
                                            'val_iters': iters_cnt['val']
                                        })

                        metrics_values_val = {
                            'loss': overall_val_loss / iters_cnt['val'],
                        }
                        sly.report_metrics_validation(epoch_flt,
                                                      metrics_values_val)
                        logger.info("Validation has been finished",
                                    extra={'epoch': epoch_flt})

                        eval_planner.validation_performed()

                        val_loss = metrics_values_val['loss']
                        model_is_best = val_loss < best_val_loss
                        if model_is_best:
                            best_val_loss = val_loss
                            logger.info(
                                'It\'s been determined that current model is the best one for a while.'
                            )

                        save_cback(saver,
                                   sess,
                                   model_is_best,
                                   opt_data={
                                       'epoch': epoch_flt,
                                       'val_metrics': metrics_values_val,
                                   })

                logger.info("Epoch was finished", extra={'epoch': epoch_flt})
        finally:
            # Stop and join the queue-runner threads before releasing the session.
            coord.request_stop()
            try:
                coord.join(threads)
            finally:
                sess.close()
Пример #3
0
def main():
    """Imitate a Supervisely training task end to end.

    The function:

    * reads the task settings;
    * creates or loads a fake model;
    * walks the input project, reading every image and annotation;
    * runs a fake training loop that reports random metrics;
    * performs a fake validation after each epoch;
    * saves a checkpoint after each epoch.

    Raises:
        RuntimeError: if an input image cannot be read by OpenCV.
    """
    # Please note that auxiliary methods from sly (supervisely_lib) use supervisely_lib.logger to format output.
    # So don't replace formatters or handlers of the logger.
    # One may use other loggers or simple prints for other output, but it's recommended to use supervisely_lib.logger.
    logger.info('Hello ML world')
    print('Glad to see u')

    # TaskHelperTrain contains almost everything needed to run training as a Supervisely task,
    # including task settings and paths to data and models.
    task_helper = sly.TaskHelperTrain()

    # All settings and parameters are passed to the task in a json file.
    # Content of the file is entirely dependent on model implementation.
    training_settings = task_helper.task_settings
    logger.info('Task settings are read',
                extra={'task_settings': training_settings})
    cnt_epochs = training_settings["epochs"]  # in the fake model we want cnt of epochs
    cnt_iters_per_epoch = training_settings["iters_per_epoch"]

    # Let's imitate model weights loading.
    # Task acquires directory with input model weights (e.g. to continue training or to initialize some parameters).
    # Content of the directory is entirely dependent on model implementation.
    model_dir = task_helper.paths.model_dir
    if task_helper.model_dir_is_empty():
        model = create_fake_model()
        logger.info('Model created from scratch')
    else:
        model = load_fake_model(model_dir)
        logger.info('Init model weights are loaded',
                    extra={'model_dir': model_dir})

    # We will save weights of trained model (checkpoints) into directories provided by the checkpoints_saver.
    checkpoints_saver = task_helper.checkpoints_saver
    logger.info('Ready to save checkpoints',
                extra={'results_dir': task_helper.paths.results_dir})

    # Let's imitate reading input project with training data.
    # Of course in real implementations it is usually wrapped in some data loaders which are executed in parallel.
    project_meta = task_helper.in_project_meta  # Project meta contains list of project classes.
    project_dir = task_helper.paths.project_dir
    project_fs = sly.ProjectFS.from_disk_dir_project(project_dir)
    # ProjectFS enlists all samples (image/annotation pairs) in input project.
    for item_descr in project_fs:
        logger.info('Processing input sample',
                    extra={
                        'dataset': item_descr.ds_name,
                        'image_name': item_descr.image_name
                    })

        # Open some image...
        img = cv2.imread(item_descr.img_path)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # rather than raising; fail with a clear message instead of an
            # AttributeError on img.shape below.
            raise RuntimeError(
                'Unable to read image: {}'.format(item_descr.img_path))
        logger.info('Read image from input project',
                    extra={
                        'width': img.shape[1],
                        'height': img.shape[0]
                    })

        # And read corresponding annotation...
        ann_packed = sly.json_load(item_descr.ann_path)
        ann = sly.Annotation.from_packed(ann_packed, project_meta)
        logger.info('Read annotation from input project',
                    extra={
                        'object_cnt': len(ann['objects']),
                        'tags': ann['tags']
                    })

    # We are to report progress of task over sly.ProgressCounter if we want to observe the progress in web panel.
    # In fact one task may report progress for some sequential (not nested) subtasks,
    # but here we will report training progress only.
    progress = sly.progress_counter_train(cnt_epochs, cnt_iters_per_epoch)

    epoch_flt = 0
    for epoch in range(cnt_epochs):
        logger.info("Epoch started", extra={'epoch': epoch})

        for train_iter in range(cnt_iters_per_epoch):
            logger.info('Some forward-backward pass...')
            time.sleep(1)

            # Call it after every iteration to report progress.
            progress.iter_done_report()
            epoch_flt = sly.epoch_float(epoch, train_iter + 1,
                                        cnt_iters_per_epoch)

            # And we are to report some metrics if we want to observe those amazing charts in web panel.
            # Regrettably, only the fixed metric types may be displayed now: 'loss', 'accuracy' and 'dice'.
            metric_values_train = {
                'loss': random.random(),
                'my_metric': random.uniform(0, 100)
            }
            sly.report_metrics_training(epoch_flt, metric_values_train)

        logger.info("Epoch finished", extra={'epoch': epoch})

        # Validation is not necessary but may be performed. So let's imitate validation...
        logger.info("Validation...")
        time.sleep(2)
        # Metrics for validation may also be reported.
        metric_values_val = {
            'loss': random.random(),
            'my_metric': random.uniform(0, 20)
        }
        sly.report_metrics_validation(epoch_flt, metric_values_val)

        # Save trained model weights when you want.
        # Model weights (checkpoint) should be written into directory provided by sly checkpoints_saver.
        # Content of the directory is entirely dependent on model implementation.
        cur_checkpoint_dir = checkpoints_saver.get_dir_to_write()
        dump_fake_model(cur_checkpoint_dir, model)
        checkpoints_saver.saved(is_best=True,
                                optional_data={
                                    'epoch': epoch_flt,
                                    'val_metrics': metric_values_val
                                })
        # It is necessary to call checkpoints_saver.saved after saving.
        # By default new model will be created from the best checkpoint over the whole training
        # (which is determined by "is_best" flag).
        # Some optional info may be provided. It will be linked with the checkpoint
        # and may help to distinguish checkpoints from same training.

    # Thank you for your patience.
    logger.info('Training finished')
Пример #4
0
def train(datasets_dicts,
          epochs,
          val_every,
          iters_cnt,
          validate_with_eval_model,
          pipeline_config,
          num_clones=1,
          save_cback=None):
    """Build a multi-GPU object-detection training graph and train it with periodic validation.

    Args:
        datasets_dicts: mapping with 'train' and 'val' entries. Each is turned
            into an input iterator via ``build_dataset``.
        epochs: number of training epochs.
        val_every: validation frequency, passed to ``EvalPlanner``.
        iters_cnt: mapping with 'train' (steps per epoch) and 'val' (batches
            per validation pass).
        validate_with_eval_model: when true, the validation model is built with
            ``is_training=False``; otherwise with ``is_training=True``.
        pipeline_config: passed to ``configs_from_pipeline`` to obtain the
            model and train configs.
        num_clones: number of model clones. Clone ``i`` runs on ``/gpu:i``,
            with model variables pinned to ``/device:CPU:0``. The deployment
            config now uses this value too (it was previously hard-coded to 4).
        save_cback: optional callable invoked as
            ``save_cback(saver, sess, model_is_best, opt_data=...)`` after every
            validation pass. When ``None``, no checkpoints are saved (previously
            this raised ``TypeError``).

    The queue-runner threads are stopped and joined, and the session is
    closed, when training finishes or raises.
    """
    logger.info('Start train')
    configs = configs_from_pipeline(pipeline_config)

    model_config = configs['model']
    train_config = configs['train_config']

    create_model_fn = functools.partial(model_builder.build,
                                        model_config=model_config,
                                        is_training=True)
    # Used below only for restore_map() when a fine-tune checkpoint is set.
    detection_model = create_model_fn()

    def get_next(dataset):
        return dataset_util.make_initializable_iterator(
            build_dataset(dataset)).get_next()

    create_tensor_dict_fn = functools.partial(get_next,
                                              datasets_dicts['train'])
    create_tensor_dict_fn_val = functools.partial(get_next,
                                                  datasets_dicts['val'])

    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=False,
            replica_id=0,
            num_replicas=1,
            num_ps_tasks=0,
            worker_job_name='lonely_worker')

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            coord = coordinator.Coordinator()
            input_queue = create_input_queue(
                train_config.batch_size, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

            # NOTE(review): the validation queue receives the same
            # data_augmentation_options as training, so validation batches are
            # augmented too - confirm this is intended.
            input_queue_val = create_input_queue(
                train_config.batch_size, create_tensor_dict_fn_val,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # create validation graph
        create_model_fn_val = functools.partial(
            model_builder.build,
            model_config=model_config,
            is_training=not validate_with_eval_model)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var, family='LearningRate')

        # One loss/gradient set per clone. Variables are created by clone 0 and
        # reused by the others; each clone's loss is divided by num_clones so the
        # summed loss is the mean over clones.
        train_losses = []
        grads_and_vars = []
        with slim.arg_scope([slim.model_variable, slim.variable],
                            device='/device:CPU:0'):
            for curr_dev_id in range(num_clones):
                with tf.device('/gpu:{}'.format(curr_dev_id)):
                    with tf.name_scope(
                            'clone_{}'.format(curr_dev_id)) as scope:
                        with tf.variable_scope(
                                tf.get_variable_scope(),
                                reuse=True if curr_dev_id > 0 else None):
                            # NOTE(review): the training losses are built with
                            # _create_losses_val - confirm this helper is meant
                            # for training as well.
                            losses = _create_losses_val(
                                input_queue, create_model_fn, train_config)
                            clones_loss = tf.add_n(losses)
                            clones_loss = tf.divide(clones_loss,
                                                    1.0 * num_clones)
                            grads = training_optimizer.compute_gradients(
                                clones_loss)
                            train_losses.append(clones_loss)
                            grads_and_vars.append(grads)
                            if curr_dev_id == 0:
                                update_ops = tf.get_collection(
                                    tf.GraphKeys.UPDATE_OPS)

        val_total_loss = get_val_loss(num_clones, input_queue_val,
                                      create_model_fn_val, train_config)

        with tf.device(deploy_config.optimizer_device()):
            total_loss = tf.add_n(train_losses)
            grads_and_vars = model_deploy._sum_clones_gradients(grads_and_vars)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates. Running train_tensor applies them (plus
            # the collected update ops) before returning the loss.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        sess_config = tf.ConfigProto(allow_soft_placement=True,
                                     log_device_placement=False)
        coord.clear_stop()
        sess = tf.Session(config=sess_config)
        saver = tf.train.Saver()

        graph = ops.get_default_graph()
        with graph.as_default():
            with ops.name_scope('init_ops'):
                init_op = variables.global_variables_initializer()
                ready_op = variables.report_uninitialized_variables()
                local_init_op = control_flow_ops.group(
                    variables.local_variables_initializer(),
                    lookup_ops.tables_initializer())

        sess.run([init_op, ready_op, local_init_op])

        queue_runners = graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
        threads = []
        for qr in queue_runners:
            threads.extend(
                qr.create_threads(sess, coord=coord, daemon=True, start=True))

        try:
            logger.info('Start restore')
            if train_config.fine_tune_checkpoint:
                var_map = detection_model.restore_map(
                    fine_tune_checkpoint_type=train_config.
                    fine_tune_checkpoint_type,
                    load_all_detection_checkpoint_vars=(
                        train_config.load_all_detection_checkpoint_vars))
                available_var_map = (
                    variables_helper.get_variables_available_in_checkpoint(
                        var_map, train_config.fine_tune_checkpoint))
                # Start counting steps from zero instead of the checkpoint's step.
                if 'global_step' in available_var_map:
                    del available_var_map['global_step']
                init_saver = tf.train.Saver(available_var_map)
                logger.info('Restoring model weights from previous checkpoint.')
                init_saver.restore(sess, train_config.fine_tune_checkpoint)
                logger.info('Model restored.')

            eval_planner = EvalPlanner(epochs, val_every)
            progress = sly.progress_counter_train(epochs, iters_cnt['train'])
            best_val_loss = float('inf')
            epoch_flt = 0

            for epoch in range(epochs):
                logger.info("Before new epoch", extra={'epoch': epoch_flt})
                for train_it in range(iters_cnt['train']):
                    # Do not rebind total_loss (the graph tensor) to a numpy value.
                    loss_value = sess.run(train_tensor)

                    metrics_values_train = {
                        'loss': loss_value,
                    }

                    progress.iter_done_report()
                    epoch_flt = epoch_float(epoch, train_it + 1,
                                            iters_cnt['train'])
                    sly.report_metrics_training(epoch_flt, metrics_values_train)

                    if eval_planner.need_validation(epoch_flt):
                        logger.info("Before validation",
                                    extra={'epoch': epoch_flt})

                        overall_val_loss = 0
                        for val_it in range(iters_cnt['val']):
                            overall_val_loss += sess.run(val_total_loss)

                            logger.info("Validation in progress",
                                        extra={
                                            'epoch': epoch_flt,
                                            'val_iter': val_it,
                                            'val_iters': iters_cnt['val']
                                        })

                        metrics_values_val = {
                            'loss': overall_val_loss / iters_cnt['val'],
                        }
                        sly.report_metrics_validation(epoch_flt,
                                                      metrics_values_val)
                        logger.info("Validation has been finished",
                                    extra={'epoch': epoch_flt})

                        eval_planner.validation_performed()

                        val_loss = metrics_values_val['loss']
                        model_is_best = val_loss < best_val_loss
                        if model_is_best:
                            best_val_loss = val_loss
                            logger.info(
                                'It\'s been determined that current model is the best one for a while.'
                            )

                        if save_cback is not None:
                            save_cback(saver,
                                       sess,
                                       model_is_best,
                                       opt_data={
                                           'epoch': epoch_flt,
                                           'val_metrics': metrics_values_val,
                                       })

                logger.info("Epoch was finished", extra={'epoch': epoch_flt})
        finally:
            # Stop and join the queue-runner threads before releasing the session.
            coord.request_stop()
            try:
                coord.join(threads)
            finally:
                sess.close()
Пример #5
0
    def train(self):
        """Train ``self.model`` for ``self.epochs`` epochs with periodic validation.

        The optimizer (Adam) is owned by an ``LRPolicyWithPatience`` built from
        ``self.config['lr']`` and ``self.config['lr_decreasing']``.

        Every step:

        * runs forward and backward passes;
        * reports the loss plus every metric in ``self.metrics`` via
          ``sly.report_metrics_training``;
        * when ``self.eval_planner`` requests it, runs ``self._validation()``,
          dumps a checkpoint (flagged as best when the validation loss is the
          lowest so far) and passes the validation loss to
          ``policy.reset_if_needed``.

        When the environment variable ``DEBUG_PATCHES_PROB`` parses to a
        positive float, a ``DebugSaver`` writing to
        ``<debug_dir>/debug_patches`` receives each batch's inputs, targets and
        softmax outputs.
        """
        progress = sly.progress_counter_train(self.epochs, self.train_iters)
        self.model.train()

        decrease_cfg = self.config['lr_decreasing']
        policy = LRPolicyWithPatience(optim_cls=Adam,
                                      init_lr=self.config['lr'],
                                      patience=decrease_cfg['patience'],
                                      lr_divisor=decrease_cfg['lr_divisor'],
                                      model=self.model)
        lowest_val_loss = float('inf')

        patch_prob = float(os.getenv('DEBUG_PATCHES_PROB', 0.0))
        patch_saver = None
        if patch_prob > 0:
            # Scale class indices so they span the 0..255 range in saved images.
            class_scale = int(255.0 / len(self.out_classes))
            patch_saver = DebugSaver(
                odir=osp.join(self.helper.paths.debug_dir, 'debug_patches'),
                prob=patch_prob,
                target_multi=class_scale)

        for epoch in range(self.epochs):
            logger.info("Before new epoch", extra={'epoch': self.epoch_flt})

            for step, batch in enumerate(self.data_loaders['train']):
                inputs_cpu, targets_cpu = batch
                inputs = cuda_variable(inputs_cpu)
                targets = cuda_variable(targets_cpu)

                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                if patch_saver is not None:
                    probs = functional.softmax(outputs, dim=1)
                    patch_saver.process(inputs_cpu, targets_cpu, probs.data.cpu())

                policy.optimizer.zero_grad()
                loss.backward()
                policy.optimizer.step()

                train_metrics = {'loss': loss.data[0]}
                train_metrics.update(
                    (metric_name, metric_fn(outputs, targets))
                    for metric_name, metric_fn in self.metrics.items())

                progress.iter_done_report()

                self.epoch_flt = epoch_float(epoch, step + 1, self.train_iters)
                sly.report_metrics_training(self.epoch_flt, train_metrics)

                if not self.eval_planner.need_validation(self.epoch_flt):
                    continue

                val_metrics = self._validation()
                self.eval_planner.validation_performed()

                val_loss = val_metrics['loss']
                is_best = val_loss < lowest_val_loss
                if is_best:
                    lowest_val_loss = val_loss
                    logger.info(
                        'It\'s been determined that current model is the best one for a while.'
                    )

                self._dump_model(is_best,
                                 opt_data={
                                     'epoch': self.epoch_flt,
                                     'val_metrics': val_metrics,
                                 })

                policy.reset_if_needed(val_loss, self.model)

            logger.info("Epoch was finished", extra={'epoch': self.epoch_flt})