def validate(data_loader, model, epoch, logger=None):
    """Run one no-grad evaluation pass and return the epoch IoU metric.

    Args:
        data_loader: iterable of batch dicts mapping str keys to CPU tensors;
            each batch must contain a 'point_cloud' tensor whose first
            dimension is the batch size.
        model: callable returning a ``(losses, metrics)`` pair of dicts of
            tensors when applied to a batch dict.
        epoch: zero-based epoch index (displayed as ``epoch + 1``).
        logger: optional summary writer exposing
            ``scalar_summary(tag, value, step)``; skipped when None.

    Returns:
        The averaged ``'IoU_' + str(cfg.IOU_THRESH)`` metric over the pass.
    """
    data_time_meter = AverageMeter()
    batch_time_meter = AverageMeter()

    model.eval()

    tic = time.time()
    training_states = TrainingStates()

    for data_dicts in data_loader:
        data_time_meter.update(time.time() - tic)

        batch_size = data_dicts['point_cloud'].shape[0]

        with torch.no_grad():
            # NOTE(review): assumes every value in the batch dict is a tensor
            # (all are moved to GPU) — confirm against the dataset's collate.
            data_dicts_var = {
                key: value.cuda()
                for key, value in data_dicts.items()
            }

            losses, metrics = model(data_dicts_var)
            # .mean() collapses per-GPU values to a scalar under DataParallel.
            losses_reduce = {
                key: value.detach().mean().item()
                for key, value in losses.items()
            }
            metrics_reduce = {
                key: value.detach().mean().item()
                for key, value in metrics.items()
            }

        training_states.update_states(dict(**losses_reduce, **metrics_reduce),
                                      batch_size)

        batch_time_meter.update(time.time() - tic)
        tic = time.time()

    states = training_states.get_states(avg=True)

    states_str = training_states.format_states(states)
    output_str = 'Validation Epoch: {:03d} Time:{:.3f}/{:.3f} ' \
        .format(epoch + 1, data_time_meter.val, batch_time_meter.val)

    logging.info(output_str + states_str)

    if logger is not None:
        for tag, value in states.items():
            logger.scalar_summary(tag, value, int(epoch))

    return states['IoU_' + str(cfg.IOU_THRESH)]
def train(data_loader, model, optimizer, lr_scheduler, epoch, logger=None):
    """Run one training epoch over `data_loader`.

    Args:
        data_loader: iterable of batch dicts mapping str keys to CPU tensors;
            each batch must contain a 'point_cloud' tensor whose first
            dimension is the batch size.
        model: callable returning a ``(losses, metrics)`` pair of dicts of
            tensors; ``losses['total_loss']`` is backpropagated.
        optimizer: torch optimizer whose param groups are stepped per batch.
        lr_scheduler: scheduler stepped once per epoch (legacy
            ``step(epoch)`` call order, as used elsewhere in this codebase).
        epoch: zero-based epoch index (displayed as ``epoch + 1``).
        logger: optional summary writer exposing
            ``scalar_summary(tag, value, step)``; skipped when None.
    """
    data_time_meter = AverageMeter()
    batch_time_meter = AverageMeter()

    model.train()

    # Step the scheduler for this epoch, then clamp the LR from below so it
    # never decays past cfg.TRAIN.MIN_LR (0 disables the clamp).
    MIN_LR = cfg.TRAIN.MIN_LR
    lr_scheduler.step(epoch)
    if MIN_LR > 0 and lr_scheduler.get_lr()[0] < MIN_LR:
        for param_group in optimizer.param_groups:
            param_group['lr'] = MIN_LR

    cur_lr = optimizer.param_groups[0]['lr']

    tic = time.time()
    loader_size = len(data_loader)

    training_states = TrainingStates()

    for i, data_dicts in enumerate(data_loader):
        data_time_meter.update(time.time() - tic)

        batch_size = data_dicts['point_cloud'].shape[0]

        # NOTE(review): assumes every value in the batch dict is a tensor
        # (all are moved to GPU) — confirm against the dataset's collate.
        data_dicts_var = {
            key: value.cuda()
            for key, value in data_dicts.items()
        }
        optimizer.zero_grad()

        losses, metrics = model(data_dicts_var)

        # .mean() collapses per-GPU losses to a scalar under DataParallel.
        loss = losses['total_loss'].mean()
        loss.backward()
        optimizer.step()

        losses_reduce = {
            key: value.detach().mean().item()
            for key, value in losses.items()
        }
        metrics_reduce = {
            key: value.detach().mean().item()
            for key, value in metrics.items()
        }

        training_states.update_states(dict(**losses_reduce, **metrics_reduce),
                                      batch_size)

        batch_time_meter.update(time.time() - tic)
        tic = time.time()

        if (i + 1) % cfg.disp == 0 or (i + 1) == loader_size:

            states = training_states.get_states(avg=False)

            states_str = training_states.format_states(states)
            output_str = 'Train Epoch: {:03d} [{:04d}/{}] lr:{:.6f} Time:{:.3f}/{:.3f} ' \
                .format(epoch + 1, i + 1, loader_size, cur_lr, data_time_meter.val, batch_time_meter.val)

            logging.info(output_str + states_str)

            if (i + 1) == loader_size:
                states = training_states.get_states(avg=True)
                states_str = training_states.format_states(states)
                output_str = 'Train Epoch(AVG): {:03d} [{:04d}/{}] lr:{:.6f} Time:{:.3f}/{:.3f} ' \
                    .format(epoch + 1, i + 1, loader_size, cur_lr, data_time_meter.val, batch_time_meter.val)
                logging.info(output_str + states_str)

    # BUGFIX: the original ran this block inside the batch loop, recomputing
    # the epoch-average states and rewriting the SAME step ``int(epoch)`` on
    # every iteration. Logging once after the loop emits the identical final
    # values with a single scalar_summary call per tag.
    if logger is not None:
        states = training_states.get_states(avg=True)
        for tag, value in states.items():
            logger.scalar_summary(tag, value, int(epoch))