def train_epoch_end(config, compression_algo, net, epoch, iteration, epoch_size, lr_scheduler, optimizer,
                    test_data_loader):
    """Perform the epoch-boundary work: scheduler steps, periodic evaluation, checkpointing.

    The compression scheduler is always stepped with ``epoch``. A learning-rate
    scheduler that is not a ``ReduceLROnPlateau`` is stepped with ``epoch`` as
    well; a ``ReduceLROnPlateau`` is stepped only with the value returned by
    ``test_net`` when an evaluation is run.

    Evaluation runs every ``max(config.test_interval // epoch_size, 1)`` epochs,
    except at iteration 0. An intermediate checkpoint is written on the first
    rank whenever ``epoch`` is positive and a multiple of ``config.save_freq``.
    """
    eval_period = max(config.test_interval // epoch_size, 1)

    compression_algo.scheduler.epoch_step(epoch)
    plateau_scheduler = isinstance(lr_scheduler, ReduceLROnPlateau)
    if not plateau_scheduler:
        lr_scheduler.step(epoch)

    time_to_evaluate = iteration != 0 and epoch % eval_period == 0
    if time_to_evaluate:
        if is_on_first_rank(config):
            print_statistics(compression_algo.statistics())
        with torch.no_grad():
            net.eval()
            mAP = test_net(net, config.device, test_data_loader,
                           distributed=config.multiprocessing_distributed)
            if plateau_scheduler:
                # The plateau scheduler is driven by the evaluation result.
                lr_scheduler.step(mAP)
            net.train()

    # Nothing to save: epoch 0, not a save epoch, or not the first rank.
    if epoch <= 0 or epoch % config.save_freq != 0 or not is_on_first_rank(config):
        return

    print('Saving state, iter:', iteration)
    checkpoints_dir = config.intermediate_checkpoints_path
    checkpoint_name = "{}_{}.pth".format(config.model, iteration)
    checkpoint_file_path = osp.join(checkpoints_dir, checkpoint_name)
    snapshot = {
        'state_dict': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'iter': iteration,
        'scheduler': compression_algo.scheduler.state_dict()
    }
    torch.save(snapshot, str(checkpoint_file_path))
def train_epoch_end(config, compression_algo, net, epoch, iteration,
                    epoch_size, lr_scheduler, optimizer, test_data_loader,
                    best_mAp):
    """Perform the epoch-boundary work and return the best evaluation result so far.

    The compression scheduler is always stepped with ``epoch``. A learning-rate
    scheduler that is not a ``ReduceLROnPlateau`` is stepped with ``epoch``; a
    ``ReduceLROnPlateau`` is stepped with the value returned by ``test_net``
    whenever an evaluation is run.

    Evaluation runs every ``max(config.test_interval // epoch_size, 1)`` epochs,
    except at iteration 0. On the first rank the "last" checkpoint is rewritten
    on every call and handed to ``make_additional_checkpoints`` together with
    a flag telling whether this call improved on ``best_mAp``.

    Returns:
        ``best_mAp``, replaced by the new evaluation result if that was larger.
    """
    eval_period = max(config.test_interval // epoch_size, 1)
    improved = False

    compression_algo.scheduler.epoch_step(epoch)
    plateau_scheduler = isinstance(lr_scheduler, ReduceLROnPlateau)
    if not plateau_scheduler:
        lr_scheduler.step(epoch)

    if iteration != 0 and epoch % eval_period == 0:
        if is_on_first_rank(config):
            print_statistics(compression_algo.statistics())
        with torch.no_grad():
            net.eval()
            mAP = test_net(net, config.device, test_data_loader,
                           distributed=config.multiprocessing_distributed)
            if mAP > best_mAp:
                improved, best_mAp = True, mAP
            if config.metrics_dump is not None:
                write_metrics(mAP, config)
            if plateau_scheduler:
                # The plateau scheduler is driven by the evaluation result.
                lr_scheduler.step(mAP)
            net.train()

    # Only the first rank writes checkpoints.
    if not is_on_first_rank(config):
        return best_mAp

    save_dir = config.checkpoint_save_dir
    checkpoint_name = "{}_last.pth".format(get_name(config))
    checkpoint_file_path = osp.join(save_dir, checkpoint_name)
    snapshot = {
        'state_dict': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'iter': iteration,
        'scheduler': compression_algo.scheduler.state_dict()
    }
    torch.save(snapshot, str(checkpoint_file_path))
    make_additional_checkpoints(checkpoint_file_path,
                                is_best=improved,
                                epoch=epoch + 1,
                                config=config)
    return best_mAp
def main_worker(current_gpu, config):
    """Set up one worker process, then export, evaluate or train the model.

    After the common setup the function takes one of three paths:
    ONNX export when ``config.to_onnx`` is set, evaluation when
    ``config.mode`` is ``'test'`` (case-insensitive), otherwise ``train``.

    Args:
        current_gpu: value stored into ``config.current_gpu``.
        config: experiment configuration, read through both attribute and item
            access. It is mutated in place: ``current_gpu``, ``distributed``,
            ``device`` and ``start_iter`` are set, and in distributed runs
            ``batch_size`` and ``workers`` are divided by ``ngpus_per_node``.
    """
    #################################
    # Setup experiment environment
    #################################
    config.current_gpu = current_gpu
    config.distributed = config.execution_mode in (ExecutionMode.DISTRIBUTED, ExecutionMode.MULTIPROCESSING_DISTRIBUTED)
    if config.distributed:
        configure_distributed(config)
    # Logging setup and argument dump happen on the first rank only.
    if is_on_first_rank(config):
        configure_logging(config)
        print_args(config)

    config.device = get_device(config)
    config.start_iter = 0

    ##########################
    # Prepare metrics log file
    ##########################

    # Record an initial value of 0 under the checkpoint's file name.
    if config.metrics_dump and config.resuming_checkpoint is not None:
        avg = 0
        metrics = {os.path.basename(config.resuming_checkpoint): avg}
        write_metrics(config, metrics)

    ##################
    # Prepare model
    ##################

    compression_algo, net = create_model(config)
    if config.distributed:
        # Integer-divide batch size and worker count by the GPUs per node.
        config.batch_size //= config.ngpus_per_node
        config.workers //= config.ngpus_per_node
        compression_algo.distributed()

    ###########################
    # Criterion and optimizer
    ###########################

    params_to_optimize = get_parameter_groups(net, config)
    optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)

    criterion = MultiBoxLoss(
        config,
        config['num_classes'],
        overlap_thresh=0.5,
        prior_for_matching=True,
        bkg_label=0,
        neg_mining=True,
        neg_pos=3,
        neg_overlap=0.5,
        encode_target=False,
        device=config.device
    )

    ###########################
    # Load checkpoint
    ###########################

    resuming_checkpoint = config.resuming_checkpoint
    if resuming_checkpoint:
        print('Resuming training, loading {}...'.format(resuming_checkpoint))
        checkpoint = torch.load(resuming_checkpoint, map_location='cpu')
        # use checkpoint itself in case of only state dict is saved
        # i.e. checkpoint is created with `torch.save(module.state_dict())`
        state_dict = checkpoint.get('state_dict', checkpoint)
        load_state(net, state_dict, is_resume=True)
        # Scheduler/optimizer state and the iteration counter are restored
        # only for training runs that are not ONNX exports.
        if config.mode.lower() == 'train' and config.to_onnx is None:
            # NOTE(review): unlike 'optimizer' and 'iter', the 'scheduler'
            # key is required here; a bare state dict raises KeyError.
            compression_algo.scheduler.load_state_dict(checkpoint['scheduler'])
            optimizer.load_state_dict(checkpoint.get('optimizer', optimizer.state_dict()))
            # Continue with the iteration after the one stored in the checkpoint.
            config.start_iter = checkpoint.get('iter', 0) + 1

    # ONNX export path: export and stop before any data is loaded.
    if config.to_onnx:
        compression_algo.export_model(config.to_onnx)
        print("Saved to {}".format(config.to_onnx))
        return

    ###########################
    # Prepare data
    ###########################

    test_data_loader, train_data_loader = create_dataloaders(config)

    # Evaluation path: run the test once, optionally dump the metric, and stop.
    if config.mode.lower() == 'test':
        with torch.no_grad():
            print_statistics(compression_algo.statistics())
            net.eval()
            mAp = test_net(net, config.device, test_data_loader, distributed=config.distributed)
            if config.metrics_dump and config.resuming_checkpoint is not None:
                # The metric is stored as a percentage rounded to two decimals.
                avg = mAp*100
                metrics = {os.path.basename(config.resuming_checkpoint): round(avg, 2)}
                write_metrics(config, metrics)
            return

    # Data-driven initialization is skipped when resuming from a checkpoint.
    if not resuming_checkpoint:
        compression_algo.initialize(train_data_loader)

    train(net, compression_algo, train_data_loader, test_data_loader, criterion, optimizer, config, lr_scheduler)
def train(net, compression_algo, train_data_loader, test_data_loader, criterion, optimizer, config, lr_scheduler):
    """Run the iteration-based training loop.

    Iterations run from ``config.start_iter`` up to ``config['max_iter']``
    (exclusive). At the first iteration of every epoch the data iterator is
    recreated and ``train_epoch_end`` is called. Every iteration steps the
    compression scheduler, performs one ``train_step`` and one optimizer
    step, and on the first rank logs the losses and rewrites the "last"
    checkpoint.

    Args:
        net: model to train; switched to training mode on entry.
        compression_algo: object providing ``scheduler`` and ``statistics()``.
        train_data_loader: training data; its length defines the epoch size.
        test_data_loader: forwarded to ``train_epoch_end``.
        criterion: forwarded to ``train_step``.
        optimizer: optimizer whose ``zero_grad``/``step`` bracket each step.
        config: experiment configuration (attribute and item access).
        lr_scheduler: forwarded to ``train_epoch_end``.
    """
    net.train()

    epoch_size = len(train_data_loader)
    print('Training ', config.model, ' on ', train_data_loader.dataset.name, ' dataset...')
    batch_iterator = None

    t_start = time.time()
    print_statistics(compression_algo.statistics())

    for iteration in range(config.start_iter, config['max_iter']):
        starts_epoch = iteration % epoch_size == 0
        if batch_iterator is None or starts_epoch:
            # create batch iterator
            batch_iterator = iter(train_data_loader)

        epoch = iteration // epoch_size
        if starts_epoch:
            train_epoch_end(config, compression_algo, net, epoch, iteration, epoch_size, lr_scheduler, optimizer,
                            test_data_loader)

        # The compression scheduler counts steps relative to the (re)start.
        compression_algo.scheduler.step(iteration - config.start_iter)

        optimizer.zero_grad()
        batch_iterator, batch_loss, batch_loss_c, batch_loss_l, loss_comp = train_step(
            batch_iterator, compression_algo, config, criterion, net, train_data_loader
        )
        optimizer.step()

        # Normalize each loss by iter_size exactly once. The components are
        # already normalized here, so their sum must not be divided again.
        batch_loss_l = batch_loss_l / config.iter_size
        batch_loss_c = batch_loss_c / config.iter_size
        model_loss = batch_loss_l + batch_loss_c
        batch_loss = batch_loss / config.iter_size

        ###########################
        # Logging
        ###########################

        if is_on_first_rank(config):
            config.tb.add_scalar("train/loss_l", batch_loss_l.item(), iteration)
            config.tb.add_scalar("train/loss_c", batch_loss_c.item(), iteration)
            config.tb.add_scalar("train/loss", batch_loss.item(), iteration)

            # NOTE(review): this checkpoint is rewritten on every iteration
            # and always flagged as best -- confirm that this is intended.
            checkpoint_file_path = osp.join(config.checkpoint_save_dir, "{}_last.pth".format(get_name(config)))
            torch.save({
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                # Store the iteration actually reached, matching
                # train_epoch_end; storing max_iter would make a resumed
                # run start past the end of the loop.
                'iter': iteration,
                'scheduler': compression_algo.scheduler.state_dict()
            }, str(checkpoint_file_path))
            make_additional_checkpoints(checkpoint_file_path,
                                        is_best=True,
                                        epoch=epoch + 1,
                                        config=config)

        if iteration % config.print_freq == 0:
            t_finish = time.time()
            t_elapsed = t_finish - t_start
            t_start = time.time()
            print('{}: iter {} epoch {} || Loss: {:.4} || Time {:.4}s || lr: {} || CR loss: {}'.format(
                config.rank, iteration, epoch, model_loss.item(), t_elapsed, optimizer.param_groups[0]['lr'],
                loss_comp.item() if isinstance(loss_comp, torch.Tensor) else loss_comp
            ))
Exemplo n.º 5
0
def main_worker(current_gpu, config):
    """Set up one worker process, then export, evaluate or train the model.

    After the common setup the function takes one of three paths:
    ONNX export when ``config.to_onnx`` is set, evaluation when
    ``config.mode`` is ``'test'`` (case-insensitive), otherwise ``train``.

    Args:
        current_gpu: value stored into ``config.current_gpu``.
        config: experiment configuration, read through both attribute and item
            access. It is mutated in place: ``current_gpu``, ``distributed``,
            ``device``, ``start_iter`` and (when data loaders are created)
            ``nncf_config`` are set; in distributed runs ``batch_size`` and
            ``workers`` are divided by ``ngpus_per_node``.
    """
    #################################
    # Setup experiment environment
    #################################
    config.current_gpu = current_gpu
    config.distributed = config.execution_mode in (
        ExecutionMode.DISTRIBUTED, ExecutionMode.MULTIPROCESSING_DISTRIBUTED)
    if config.distributed:
        configure_distributed(config)
    # Logging setup and argument dump happen on the first rank only.
    if is_on_first_rank(config):
        configure_logging(logger, config)
        print_args(config)

    config.device = get_device(config)
    config.start_iter = 0

    ##########################
    # Prepare metrics log file
    ##########################

    # Record an initial value of 0 before any evaluation has run.
    if config.metrics_dump is not None:
        write_metrics(0, config.metrics_dump)

    ###########################
    # Criterion
    ###########################

    criterion = MultiBoxLoss(config,
                             config['num_classes'],
                             overlap_thresh=0.5,
                             prior_for_matching=True,
                             bkg_label=0,
                             neg_mining=True,
                             neg_pos=3,
                             neg_overlap=0.5,
                             encode_target=False,
                             device=config.device)

    # The loaders stay None on the ONNX export path, which never uses them.
    train_data_loader = test_data_loader = None
    resuming_checkpoint_path = config.resuming_checkpoint_path

    ###########################
    # Prepare data
    ###########################

    pretrained = is_pretrained_model_requested(config)

    if config.to_onnx is not None:
        # ONNX export skips data loading; it requires either a pretrained
        # model or a checkpoint to resume from.
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        test_data_loader, train_data_loader = create_dataloaders(config)
        config.nncf_config = register_default_init_args(
            config.nncf_config, criterion, train_data_loader)

    ##################
    # Prepare model
    ##################
    # NOTE(review): same value as already assigned before the data
    # preparation step; this re-read looks redundant.
    resuming_checkpoint_path = config.resuming_checkpoint_path
    resuming_checkpoint = None
    resuming_model_state_dict = None

    if resuming_checkpoint_path:
        logger.info(
            'Resuming from checkpoint {}...'.format(resuming_checkpoint_path))
        resuming_checkpoint = torch.load(resuming_checkpoint_path,
                                         map_location='cpu')
        # use checkpoint itself in case only the state dict was saved,
        # i.e. the checkpoint was created with `torch.save(module.state_dict())`
        resuming_model_state_dict = resuming_checkpoint.get(
            'state_dict', resuming_checkpoint)

    compression_ctrl, net = create_model(config, resuming_model_state_dict)
    if config.distributed:
        # Integer-divide batch size and worker count by the GPUs per node.
        config.batch_size //= config.ngpus_per_node
        config.workers //= config.ngpus_per_node
        compression_ctrl.distributed()

    ###########################
    # Optimizer
    ###########################

    params_to_optimize = get_parameter_groups(net, config)
    optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)

    #################################
    # Load additional checkpoint data
    #################################

    # Scheduler/optimizer state and the iteration counter are restored only
    # for training runs that are not ONNX exports.
    if resuming_checkpoint is not None and config.mode.lower(
    ) == 'train' and config.to_onnx is None:
        # NOTE(review): unlike 'optimizer' and 'iter', the 'scheduler' key is
        # required here; a bare state dict raises KeyError.
        compression_ctrl.scheduler.load_state_dict(
            resuming_checkpoint['scheduler'])
        optimizer.load_state_dict(
            resuming_checkpoint.get('optimizer', optimizer.state_dict()))
        # Continue with the iteration after the one stored in the checkpoint.
        config.start_iter = resuming_checkpoint.get('iter', 0) + 1

    # ONNX export path: export and stop.
    if config.to_onnx:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    # Evaluation path: run the test once, optionally dump the metric, and stop.
    if config.mode.lower() == 'test':
        with torch.no_grad():
            print_statistics(compression_ctrl.statistics())
            net.eval()
            mAp = test_net(net,
                           config.device,
                           test_data_loader,
                           distributed=config.distributed)
            if config.metrics_dump is not None:
                write_metrics(mAp, config.metrics_dump)
            return

    train(net, compression_ctrl, train_data_loader, test_data_loader,
          criterion, optimizer, config, lr_scheduler)
Exemplo n.º 6
0
def train(net, compression_ctrl, train_data_loader, test_data_loader,
          criterion, optimizer, config, lr_scheduler):
    """Run the iteration-based training loop with per-epoch evaluation and checkpoints.

    Iterations run from ``config.start_iter`` up to ``config['max_iter']``
    (exclusive). On the last iteration of every epoch, before that
    iteration's training step, the loop:

    * steps the compression scheduler with the epoch index;
    * evaluates the model every ``max(config.test_interval // epoch_size, 1)``
      epochs and updates the best result / best compression level;
    * steps the learning-rate scheduler (a ``ReduceLROnPlateau`` only when a
      fresh evaluation result exists for this epoch);
    * on the first rank, rewrites the "last" checkpoint and calls
      ``make_additional_checkpoints``.

    After the loop the best evaluation result is written with
    ``write_metrics`` when ``config.metrics_dump`` is set.

    Args:
        net: model to train; switched to training mode on entry.
        compression_ctrl: object providing ``scheduler``, ``statistics()``
            and ``compression_level()``.
        train_data_loader: training data; its length defines the epoch size.
        test_data_loader: data used by ``test_net``.
        criterion: forwarded to ``train_step``.
        optimizer: optimizer whose ``zero_grad``/``step`` bracket each step.
        config: experiment configuration (attribute and item access).
        lr_scheduler: learning-rate scheduler stepped at epoch boundaries.
    """
    net.train()

    epoch_size = len(train_data_loader)
    logger.info('Training {} on {} dataset...'.format(
        config.model, train_data_loader.dataset.name))
    batch_iterator = None

    t_start = time.time()
    print_statistics(compression_ctrl.statistics())

    best_mAp = 0
    best_compression_level = CompressionLevel.NONE
    test_freq_in_epochs = max(config.test_interval // epoch_size, 1)

    for iteration in range(config.start_iter, config['max_iter']):
        if batch_iterator is None or iteration % epoch_size == 0:
            # create batch iterator
            batch_iterator = iter(train_data_loader)

        epoch = iteration // epoch_size

        if (iteration + 1) % epoch_size == 0:
            compression_ctrl.scheduler.epoch_step(epoch)
            compression_level = compression_ctrl.compression_level()
            is_best = False
            # Stays None unless the model is evaluated in this epoch, so a
            # missing or stale value is never fed to the plateau scheduler.
            mAP = None

            if (epoch + 1) % test_freq_in_epochs == 0:
                if is_on_first_rank(config):
                    print_statistics(compression_ctrl.statistics())
                with torch.no_grad():
                    net.eval()
                    mAP = test_net(
                        net,
                        config.device,
                        test_data_loader,
                        distributed=config.multiprocessing_distributed)
                    # A result is "best" if it improves the metric at the
                    # same compression level, or if a higher compression
                    # level has been reached.
                    is_best_by_mAP = mAP > best_mAp and compression_level == best_compression_level
                    is_best = is_best_by_mAP or compression_level > best_compression_level
                    if is_best:
                        best_mAp = mAP
                    best_compression_level = max(compression_level,
                                                 best_compression_level)
                    net.train()

            # Learning rate scheduling should be applied after optimizer's update
            if not isinstance(lr_scheduler, ReduceLROnPlateau):
                lr_scheduler.step(epoch)
            elif mAP is not None:
                lr_scheduler.step(mAP)

            if is_on_first_rank(config):
                logger.info('Saving state, iter: {}'.format(iteration))

                checkpoint_file_path = osp.join(
                    config.checkpoint_save_dir,
                    "{}_last.pth".format(get_name(config)))
                torch.save(
                    {
                        'state_dict': net.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        # Store the iteration actually reached; storing
                        # max_iter would make a resumed run (which starts at
                        # iter + 1) skip the whole training loop.
                        'iter': iteration,
                        'scheduler': compression_ctrl.scheduler.state_dict(),
                        'compression_level': compression_level,
                    }, str(checkpoint_file_path))
                make_additional_checkpoints(checkpoint_file_path,
                                            is_best=is_best,
                                            epoch=epoch + 1,
                                            config=config)

        # The compression scheduler counts steps relative to the (re)start.
        compression_ctrl.scheduler.step(iteration - config.start_iter)

        optimizer.zero_grad()
        batch_iterator, batch_loss, batch_loss_c, batch_loss_l, loss_comp = train_step(
            batch_iterator, compression_ctrl, config, criterion, net,
            train_data_loader)
        optimizer.step()

        # Normalize each loss by iter_size exactly once. The components are
        # already normalized here, so their sum must not be divided again.
        batch_loss_l = batch_loss_l / config.iter_size
        batch_loss_c = batch_loss_c / config.iter_size
        model_loss = batch_loss_l + batch_loss_c
        batch_loss = batch_loss / config.iter_size

        ###########################
        # Logging
        ###########################

        if is_on_first_rank(config):
            config.tb.add_scalar("train/loss_l", batch_loss_l.item(),
                                 iteration)
            config.tb.add_scalar("train/loss_c", batch_loss_c.item(),
                                 iteration)
            config.tb.add_scalar("train/loss", batch_loss.item(), iteration)

        if iteration % config.print_freq == 0:
            t_finish = time.time()
            t_elapsed = t_finish - t_start
            t_start = time.time()
            logger.info(
                '{}: iter {} epoch {} || Loss: {:.4} || Time {:.4}s || lr: {} || CR loss: {}'
                .format(
                    config.rank, iteration, epoch, model_loss.item(),
                    t_elapsed, optimizer.param_groups[0]['lr'],
                    loss_comp.item()
                    if isinstance(loss_comp, torch.Tensor) else loss_comp))

    if config.metrics_dump is not None:
        write_metrics(best_mAp, config.metrics_dump)
Exemplo n.º 7
0
def main_worker(current_gpu, config):
    """Set up one worker process, then export, evaluate or train the model.

    After the common setup the function takes one of three paths:
    ONNX export when ``config.to_onnx`` is set, evaluation when
    ``config.mode`` is ``'test'`` (case-insensitive), otherwise ``train``.

    Args:
        current_gpu: forwarded to ``configure_device`` together with ``config``.
        config: experiment configuration, read through both attribute and item
            access. It is mutated in place: ``mlflow`` and ``start_iter`` are
            set, and in distributed runs ``batch_size`` and ``workers`` are
            divided by ``ngpus_per_node``.
    """
    #################################
    # Setup experiment environment
    #################################
    configure_device(current_gpu, config)
    config.mlflow = SafeMLFLow(config)
    # Logging setup and argument dump happen on the first rank only.
    if is_on_first_rank(config):
        configure_logging(logger, config)
        print_args(config)

    config.start_iter = 0
    nncf_config = config.nncf_config
    ##########################
    # Prepare metrics log file
    ##########################

    # Record an initial value of 0 before any evaluation has run.
    if config.metrics_dump is not None:
        write_metrics(0, config.metrics_dump)

    ###########################
    # Criterion
    ###########################

    criterion = MultiBoxLoss(config,
                             config['num_classes'],
                             overlap_thresh=0.5,
                             prior_for_matching=True,
                             bkg_label=0,
                             neg_mining=True,
                             neg_pos=3,
                             neg_overlap=0.5,
                             encode_target=False,
                             device=config.device)

    # The loaders stay None on the ONNX export path, which never uses them.
    train_data_loader = test_data_loader = None
    resuming_checkpoint_path = config.resuming_checkpoint_path

    ###########################
    # Prepare data
    ###########################

    pretrained = is_pretrained_model_requested(config)

    if config.to_onnx is not None:
        # ONNX export skips data loading; it requires either a pretrained
        # model or a checkpoint to resume from.
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        test_data_loader, train_data_loader, init_data_loader = create_dataloaders(
            config)

        # Sums the two values returned by the criterion into one loss.
        def criterion_fn(model_outputs, target, criterion):
            loss_l, loss_c = criterion(model_outputs, target)
            return loss_l + loss_c

        # Returns the negated result of a loss-inference run of test_net.
        def autoq_test_fn(model, eval_loader):
            # RL is maximization, change the loss polarity
            return -1 * test_net(model,
                                 config.device,
                                 eval_loader,
                                 distributed=config.distributed,
                                 loss_inference=True,
                                 criterion=criterion)

        # NOTE(review): the result is kept only in the local nncf_config,
        # which is not used again in this function; presumably the helper
        # updates the config object in place -- confirm.
        nncf_config = register_default_init_args(nncf_config, init_data_loader,
                                                 criterion, criterion_fn,
                                                 autoq_test_fn,
                                                 test_data_loader,
                                                 config.device)

    ##################
    # Prepare model
    ##################
    # NOTE(review): same value as already assigned before the data
    # preparation step; this re-read looks redundant.
    resuming_checkpoint_path = config.resuming_checkpoint_path

    resuming_model_sd = None
    # resuming_checkpoint is bound only when a checkpoint path is given; the
    # later use is guarded by the same path check.
    if resuming_checkpoint_path is not None:
        resuming_model_sd, resuming_checkpoint = load_resuming_model_state_dict_and_checkpoint_from_path(
            resuming_checkpoint_path)

    compression_ctrl, net = create_model(config, resuming_model_sd)
    if config.distributed:
        # Integer-divide batch size and worker count by the GPUs per node.
        config.batch_size //= config.ngpus_per_node
        config.workers //= config.ngpus_per_node
        compression_ctrl.distributed()

    ###########################
    # Optimizer
    ###########################

    params_to_optimize = get_parameter_groups(net, config)
    optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)

    #################################
    # Load additional checkpoint data
    #################################

    # Scheduler/optimizer state and the iteration counter are restored only
    # for training runs that are not ONNX exports.
    if resuming_checkpoint_path is not None and config.mode.lower(
    ) == 'train' and config.to_onnx is None:
        # NOTE(review): unlike 'optimizer' and 'iter', the 'scheduler' key is
        # required here; a checkpoint without it raises KeyError.
        compression_ctrl.scheduler.load_state_dict(
            resuming_checkpoint['scheduler'])
        optimizer.load_state_dict(
            resuming_checkpoint.get('optimizer', optimizer.state_dict()))
        # Continue with the iteration after the one stored in the checkpoint.
        config.start_iter = resuming_checkpoint.get('iter', 0) + 1

    log_common_mlflow_params(config)

    # ONNX export path: export and stop.
    if config.to_onnx:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    if is_main_process():
        print_statistics(compression_ctrl.statistics())

    # Evaluation path: report either the model loss or the test metric, then stop.
    if config.mode.lower() == 'test':
        with torch.no_grad():
            net.eval()
            if config['ssd_params'].get('loss_inference', False):
                model_loss = test_net(net,
                                      config.device,
                                      test_data_loader,
                                      distributed=config.distributed,
                                      loss_inference=True,
                                      criterion=criterion)
                logger.info("Final model loss: {:.3f}".format(model_loss))
            else:
                mAp = test_net(net,
                               config.device,
                               test_data_loader,
                               distributed=config.distributed)
                if config.metrics_dump is not None:
                    write_metrics(mAp, config.metrics_dump)
            return

    train(net, compression_ctrl, train_data_loader, test_data_loader,
          criterion, optimizer, config, lr_scheduler)