Ejemplo n.º 1
0
def get_prune_weights(model, use_transformer=False):
    """Collect the parameters that take part in pruning.

    Args:
        model: Model (possibly wrapped); it is passed through
            `mc.unwrap_model` before the lookup.
        use_transformer: When true, parameter names are read from
            `FLAGS._bn_to_prune_transformer.weight`; otherwise from
            `FLAGS._bn_to_prune.weight`.

    Returns:
        The result of `get_params_by_name` for the selected names.
    """
    bare_model = mc.unwrap_model(model)
    # Names look like 'features.2.ops.0.1.1.weight',
    # 'features.2.ops.1.1.1.weight', ...
    if use_transformer:
        weight_names = FLAGS._bn_to_prune_transformer.weight
    else:
        weight_names = FLAGS._bn_to_prune.weight
    return get_params_by_name(bare_model, weight_names)
Ejemplo n.º 2
0
def get_ema_model(ema, model_wrapper):
    """Build an evaluation model wrapper from ExponentialMovingAverage values.

    NOTE: Without `ema` the very same `model_wrapper` object is handed back,
        so mutating the result also affects whatever uses `model_wrapper`
        afterwards. With `ema`, a deep copy is created and only the copy's
        parameters are overwritten.

    Args:
        ema: Object exposing `average_names()` and `average(name)`, or None.
        model_wrapper: Wrapper around the model to evaluate.

    Returns:
        A deep-copied wrapper carrying the averaged parameters when `ema` is
        given, otherwise `model_wrapper` itself.
    """
    if ema is None:
        return model_wrapper

    ema_wrapper = copy.deepcopy(model_wrapper)
    ema_model = unwrap_model(ema_wrapper)
    averaged_names = ema.average_names()
    targets = get_params_by_name(ema_model, averaged_names)
    # Overwrite each parameter of the copy in place with its averaged value.
    for averaged_name, target in zip(averaged_names, targets):
        target.data.copy_(ema.average(averaged_name))
    return ema_wrapper
Ejemplo n.º 3
0
def run_one_epoch(epoch,
                  loader,
                  model,
                  criterion,
                  optimizer,
                  lr_scheduler,
                  ema,
                  meters,
                  max_iter=None,
                  phase='train'):
    """Run one epoch of training, evaluation or BN calibration.

    Args:
        epoch: Epoch index; forwarded to `loader.sampler.set_epoch` in
            distributed mode and used in log messages.
        loader: Iterable of `(input, target)` batches. `len(loader)` is used
            in training log messages.
        model: Model to run; switched to train mode for `phase == 'train'`
            and to eval mode otherwise.
        criterion: Loss callable, passed through to `mc.forward_loss`.
        optimizer: Optimizer; only used when `phase == 'train'`.
        lr_scheduler: Scheduler stepped once after every optimizer step.
        ema: Optional callable invoked as `ema(name, param, global_step)` for
            every name in `ema.average_names()`; skipped when None.
        meters: Meter collection handed to `mc.forward_loss` and
            `mc.reduce_and_flush_meters`.
        max_iter: Optional cap on the number of batches. Only allowed when
            `phase == 'bn_calibration'` (enforced by an assert).
        phase: One of 'train', 'val', 'test', 'bn_calibration'.

    Returns:
        The value of the last `mc.reduce_and_flush_meters(meters)` call. In
        the train phase this is the result from the most recent logging step,
        or None if no logging step was hit during the epoch.
    """
    assert phase in ['train', 'val', 'test', 'bn_calibration'
                     ], "phase not be in train/val/test/bn_calibration."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()
    # BN calibration starts from eval mode and then applies `bn_calibration`
    # through `model.apply` (presumably to every submodule — confirm `model`
    # is an `nn.Module`).
    if phase == 'bn_calibration':
        model.apply(bn_calibration)

    if FLAGS.use_distributed:
        loader.sampler.set_epoch(epoch)

    results = None
    data_iterator = iter(loader)
    if FLAGS.use_distributed:
        data_fetcher = dataflow.DataPrefetcher(data_iterator)
    else:
        # TODO(meijieru): prefetch for non distributed
        logging.warning('Not use prefetcher')
        data_fetcher = data_iterator
    for batch_idx, (input, target) in enumerate(data_fetcher):
        # `max_iter` limits the number of batches; it is only valid for BN
        # calibration.
        if max_iter is not None:
            assert phase == 'bn_calibration'
            if batch_idx >= max_iter:
                break

        # Only `target` is moved to the GPU here; `input` is passed on to
        # `mc.forward_loss` unchanged.
        target = target.cuda(non_blocking=True)
        if train:
            optimizer.zero_grad()
            loss = mc.forward_loss(model, criterion, input, target, meters)
            # Add the L2 term computed by `optim.cal_l2_loss` to the task loss.
            loss_l2 = optim.cal_l2_loss(model, FLAGS.weight_decay,
                                        FLAGS.weight_decay_method)
            loss = loss + loss_l2
            loss.backward()
            if FLAGS.use_distributed:
                udist.allreduce_grads(model)

            # Every `FLAGS.log_interval` global steps: reduce and flush the
            # meters on all processes, then log/write scalars on the master.
            if FLAGS._global_step % FLAGS.log_interval == 0:
                results = mc.reduce_and_flush_meters(meters)
                if udist.is_master():
                    logging.info('Epoch {}/{} Iter {}/{} {}: '.format(
                        epoch, FLAGS.num_epochs, batch_idx, len(loader), phase)
                                 + ', '.join('{}: {:.4f}'.format(k, v)
                                             for k, v in results.items()))
                    for k, v in results.items():
                        mc.summary_writer.add_scalar('{}/{}'.format(phase, k),
                                                     v, FLAGS._global_step)
            # Extra training scalars, written by the master on the same
            # logging steps.
            if udist.is_master(
            ) and FLAGS._global_step % FLAGS.log_interval == 0:
                mc.summary_writer.add_scalar('train/learning_rate',
                                             optimizer.param_groups[0]['lr'],
                                             FLAGS._global_step)
                mc.summary_writer.add_scalar('train/l2_regularize_loss',
                                             extract_item(loss_l2),
                                             FLAGS._global_step)
                # Fractional epoch derived from the global step counter.
                mc.summary_writer.add_scalar(
                    'train/current_epoch',
                    FLAGS._global_step / FLAGS._steps_per_epoch,
                    FLAGS._global_step)
                # The queue size is only recorded when loader workers are
                # configured.
                if FLAGS.data_loader_workers > 0:
                    mc.summary_writer.add_scalar(
                        'data/train/prefetch_size',
                        get_data_queue_size(data_iterator), FLAGS._global_step)

            optimizer.step()
            lr_scheduler.step()
            if FLAGS.use_distributed and FLAGS.allreduce_bn:
                udist.allreduce_bn(model)
            FLAGS._global_step += 1

            # NOTE: the EMA update runs after the step counter update, so it
            # receives the already-incremented global step.
            if ema is not None:
                model_unwrap = mc.unwrap_model(model)
                ema_names = ema.average_names()
                params = get_params_by_name(model_unwrap, ema_names)
                for name, param in zip(ema_names, params):
                    ema(name, param, FLAGS._global_step)
        else:
            # Non-training phases only run the forward pass; the returned
            # loss is discarded and the metrics are collected via `meters`.
            mc.forward_loss(model, criterion, input, target, meters)

    # For non-training phases the meters are reduced once, after all batches.
    if not train:
        results = mc.reduce_and_flush_meters(meters)
        if udist.is_master():
            logging.info(
                'Epoch {}/{} {}: '.format(epoch, FLAGS.num_epochs, phase) +
                ', '.join('{}: {:.4f}'.format(k, v)
                          for k, v in results.items()))
            for k, v in results.items():
                mc.summary_writer.add_scalar('{}/{}'.format(phase, k), v,
                                             FLAGS._global_step)
    return results
Ejemplo n.º 4
0
def _keypoint_forward_loss(model, criterion, inputs, target, target_weight,
                           meters):
    """Forward one keypoint batch, cache loss/accuracy in `meters`, return loss.

    When the model returns a list of outputs, the criterion is summed over
    all of them and the accuracy is computed on the last one.
    """
    outputs = model(inputs)
    if isinstance(outputs, list):
        loss = criterion(outputs[0], target, target_weight)
        for extra_output in outputs[1:]:
            loss += criterion(extra_output, target, target_weight)
        # Bind explicitly: relying on the loop variable left `output`
        # undefined (or stale) when the list held a single element.
        output = outputs[-1]
    else:
        output = outputs
        loss = criterion(output, target, target_weight)
    _, avg_acc, _, _ = accuracy_keypoint(output.detach().cpu().numpy(),
                                         target.detach().cpu().numpy())
    meters['acc'].cache(avg_acc)
    meters['loss'].cache(loss)
    return loss


def run_one_epoch(epoch,
                  loader,
                  model,
                  criterion,
                  optimizer,
                  lr_scheduler,
                  ema,
                  rho_scheduler,
                  meters,
                  max_iter=None,
                  phase='train'):
    """Run one epoch of training, evaluation, BN calibration or pruning eval.

    Args:
        epoch: Epoch index; forwarded to `loader.sampler.set_epoch` in
            distributed (non-HDFS) mode and used in log messages.
        loader: Iterable of batches. Batches are
            `(input, target, target_weight, meta)` when
            `FLAGS.dataset == 'coco'` and `(input, target)` otherwise.
        model: Model to run; train mode for `phase == 'train'`, eval mode
            otherwise.
        criterion: Loss callable.
        optimizer: Optimizer; only used when `phase == 'train'`.
        lr_scheduler: Scheduler stepped after each optimizer step unless
            `FLAGS.lr_scheduler == 'poly'`, in which case
            `optim.poly_learning_rate` is used instead.
        ema: Optional callable invoked as `ema(name, param, global_step)` for
            every name in `ema.average_names()`; skipped when None.
        rho_scheduler: Callable mapping the global step to `rho`, which is
            passed to `prune.cal_bn_l1_loss`.
        meters: Mapping of meters; 'acc', 'loss', 'loss_l2' and 'loss_bn_l1'
            entries are written here depending on the branch taken.
        max_iter: Optional cap on the number of batches. Only allowed when
            `phase == 'bn_calibration'` (enforced by an assert).
        phase: 'train', 'val', 'test', 'bn_calibration', or any string
            starting with 'prune'.

    Returns:
        The value of the last `mc.reduce_and_flush_meters(meters)` call. In
        the train phase this is the result from the most recent logging step,
        or None if no logging step was hit during the epoch.
    """
    assert phase in [
        'train', 'val', 'test', 'bn_calibration'
    ] or phase.startswith(
        'prune'), "phase not be in train/val/test/bn_calibration/prune."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()
    if phase == 'bn_calibration':
        model.apply(bn_calibration)

    if not FLAGS.use_hdfs and FLAGS.use_distributed:
        loader.sampler.set_epoch(epoch)

    results = None
    data_iterator = iter(loader)
    # Default to the plain iterator so `data_fetcher` is always bound; the
    # HDFS path previously left it undefined and crashed at the loop below.
    data_fetcher = data_iterator
    if not FLAGS.use_hdfs:
        if FLAGS.use_distributed:
            if FLAGS.dataset == 'coco':
                data_fetcher = dataflow.DataPrefetcherKeypoint(data_iterator)
            else:
                data_fetcher = dataflow.DataPrefetcher(data_iterator)
        else:
            logging.warning('Not use prefetcher')

    for batch_idx, data in enumerate(data_fetcher):
        if FLAGS.dataset == 'coco':
            # Example shapes noted by the original author:
            # (4, 3, 384, 288), (4, 17, 96, 72), (4, 17, 1).
            inputs, target, target_weight, _ = data
        else:
            inputs, target = data
        # `max_iter` limits the number of batches; it is only valid for BN
        # calibration.
        if max_iter is not None:
            assert phase == 'bn_calibration'
            if batch_idx >= max_iter:
                break

        target = target.cuda(non_blocking=True)
        if train:
            optimizer.zero_grad()
            rho = rho_scheduler(FLAGS._global_step)

            if FLAGS.dataset == 'coco':
                loss = _keypoint_forward_loss(model, criterion, inputs,
                                              target, target_weight, meters)
            else:
                loss = mc.forward_loss(model,
                                       criterion,
                                       inputs,
                                       target,
                                       meters,
                                       task=FLAGS.model_kwparams.task,
                                       distill=FLAGS.distill)
            if FLAGS.prune_params['method'] is not None:
                loss_l2 = optim.cal_l2_loss(
                    model, FLAGS.weight_decay,
                    FLAGS.weight_decay_method)  # manual weight decay
                loss_bn_l1 = prune.cal_bn_l1_loss(get_prune_weights(model),
                                                  FLAGS._bn_to_prune.penalty,
                                                  rho)
                if FLAGS.prune_params.use_transformer:
                    transformer_weights = get_prune_weights(model, True)
                    loss_bn_l1 += prune.cal_bn_l1_loss(
                        transformer_weights,
                        FLAGS._bn_to_prune_transformer.penalty, rho)

                    # Per-weight count of entries above the shrink threshold,
                    # recorded as 'channels' before the penalty is refreshed.
                    transformer_channels = [
                        sum(weight > FLAGS.model_shrink_threshold).item()
                        for _, weight in zip(
                            FLAGS._bn_to_prune_transformer.weight,
                            transformer_weights)
                    ]
                    FLAGS._bn_to_prune_transformer.add_info_list(
                        'channels', transformer_channels)
                    FLAGS._bn_to_prune_transformer.update_penalty()
                    if udist.is_master(
                    ) and FLAGS._global_step % FLAGS.log_interval == 0:
                        logging.info(transformer_channels)

                meters['loss_l2'].cache(loss_l2)
                meters['loss_bn_l1'].cache(loss_bn_l1)
                loss = loss + loss_l2 + loss_bn_l1
            loss.backward()
            if FLAGS.use_distributed:
                udist.allreduce_grads(model)

            # Every `FLAGS.log_interval` global steps: reduce and flush the
            # meters on all processes, then log/write scalars on the master.
            if FLAGS._global_step % FLAGS.log_interval == 0:
                results = mc.reduce_and_flush_meters(meters)
                if udist.is_master():
                    logging.info('Epoch {}/{} Iter {}/{} Lr: {} {}: '.format(
                        epoch, FLAGS.num_epochs, batch_idx, len(loader),
                        optimizer.param_groups[0]["lr"], phase) +
                                 ', '.join('{}: {:.4f}'.format(k, v)
                                           for k, v in results.items()))
                    for k, v in results.items():
                        mc.summary_writer.add_scalar('{}/{}'.format(phase, k),
                                                     v, FLAGS._global_step)

            if udist.is_master(
            ) and FLAGS._global_step % FLAGS.log_interval == 0:
                mc.summary_writer.add_scalar('train/learning_rate',
                                             optimizer.param_groups[0]['lr'],
                                             FLAGS._global_step)
                # The regularization terms only exist when pruning is on.
                if FLAGS.prune_params['method'] is not None:
                    mc.summary_writer.add_scalar('train/l2_regularize_loss',
                                                 extract_item(loss_l2),
                                                 FLAGS._global_step)
                    mc.summary_writer.add_scalar('train/bn_l1_loss',
                                                 extract_item(loss_bn_l1),
                                                 FLAGS._global_step)
                mc.summary_writer.add_scalar('prune/rho', rho,
                                             FLAGS._global_step)
                mc.summary_writer.add_scalar(
                    'train/current_epoch',
                    FLAGS._global_step / FLAGS._steps_per_epoch,
                    FLAGS._global_step)
                if FLAGS.data_loader_workers > 0:
                    mc.summary_writer.add_scalar(
                        'data/train/prefetch_size',
                        get_data_queue_size(data_iterator), FLAGS._global_step)

            if udist.is_master(
            ) and FLAGS._global_step % FLAGS.log_interval_detail == 0:
                summary_bn(model, 'train')

            optimizer.step()
            if FLAGS.lr_scheduler == 'poly':
                optim.poly_learning_rate(
                    optimizer, FLAGS.lr,
                    epoch * FLAGS._steps_per_epoch + batch_idx + 1,
                    FLAGS.num_epochs * FLAGS._steps_per_epoch)
            else:
                lr_scheduler.step()
            if FLAGS.use_distributed and FLAGS.allreduce_bn:
                udist.allreduce_bn(model)
            FLAGS._global_step += 1

            # NOTE: the EMA update runs after the step counter update, so it
            # receives the already-incremented global step.
            if ema is not None:
                model_unwrap = mc.unwrap_model(model)
                ema_names = ema.average_names()
                params = get_params_by_name(model_unwrap, ema_names)
                for name, param in zip(ema_names, params):
                    ema(name, param, FLAGS._global_step)
        else:
            # Non-training phases only run the forward pass; the loss is
            # discarded and the metrics are collected via `meters`.
            if FLAGS.dataset == 'coco':
                _keypoint_forward_loss(model, criterion, inputs, target,
                                       target_weight, meters)
            else:
                mc.forward_loss(model,
                                criterion,
                                inputs,
                                target,
                                meters,
                                task=FLAGS.model_kwparams.task,
                                distill=False)

    # For non-training phases the meters are reduced once, after all batches.
    if not train:
        results = mc.reduce_and_flush_meters(meters)
        if udist.is_master():
            logging.info(
                'Epoch {}/{} {}: '.format(epoch, FLAGS.num_epochs, phase) +
                ', '.join('{}: {:.4f}'.format(k, v)
                          for k, v in results.items()))
            for k, v in results.items():
                mc.summary_writer.add_scalar('{}/{}'.format(phase, k), v,
                                             FLAGS._global_step)
    return results
Ejemplo n.º 5
0
def get_prune_weights(model):
    """Look up the parameters named in `FLAGS._bn_to_prune.weight`.

    The model is passed through `unwrap_model` before the lookup, and the
    result of `get_params_by_name` is returned unchanged.
    """
    bare_model = unwrap_model(model)
    weight_names = FLAGS._bn_to_prune.weight
    return get_params_by_name(bare_model, weight_names)