Example #1
    def backward(self, grad_output):
        grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_all = self.saved_tensors
        need_input_grad, need_weight_grad, need_bias_grad = self.needs_input_grad[
            0:3]

        # calculate local stats as well as grad_weight / grad_bias
        sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output, saved_input, mean, invstd, weight, need_input_grad,
            need_weight_grad, need_bias_grad)

        if need_input_grad:
            # synchronizing stats used to calculate input gradient.
            sum_dy_handle = allreduce_async(sum_dy,
                                            op=Sum,
                                            name='sync_batch_norm.sum_dy')
            sum_dy_xmu_handle = allreduce_async(
                sum_dy_xmu, op=Sum, name='sync_batch_norm.sum_dy_xmu')

            # wait on the async communication to finish
            sum_dy = synchronize(sum_dy_handle)
            sum_dy_xmu = synchronize(sum_dy_xmu_handle)

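            # from 1.5.0 on, count_all carries the element counts from every rank,
            # so the synchronized sums are normalized by their total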
            if _SYNC_BN_V2 or _SYNC_BN_V3:
                count_all_sum = count_all.sum()
                mean_dy = sum_dy / count_all_sum
                mean_dy_xmu = sum_dy_xmu / count_all_sum
            else:
                # before 1.5.0, sum_dy was the sum of means from every worker, so we
                # just need to divide it by the number of workers
                mean_dy = sum_dy / size()
                mean_dy_xmu = sum_dy_xmu / size()

            # backward pass for gradient calculation
            grad_input = torch.batch_norm_backward_elemt(
                grad_output, saved_input, mean, invstd, weight, mean_dy,
                mean_dy_xmu)
        else:
            grad_input = None

        # synchronizing grad_weight / grad_bias is not needed, as distributed
        # training handles the allreduce.
        if weight is None or not need_weight_grad:
            grad_weight = None

        if weight is None or not need_bias_grad:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
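
Below is a minimal, self-contained sketch of the allreduce_async / synchronize pattern the backward pass above relies on, assuming Horovod's PyTorch bindings (horovod.torch); the tensor names and the demo function are hypothetical, not part of the example.

import torch
import horovod.torch as hvd

hvd.init()


def overlapped_allreduce_demo():
    # two independent local statistics, analogous to sum_dy / sum_dy_xmu above
    stat_a = torch.randn(16)
    stat_b = torch.randn(16)

    # launch both reductions without blocking so the communication can overlap
    handle_a = hvd.allreduce_async(stat_a, op=hvd.Sum, name='demo.stat_a')
    handle_b = hvd.allreduce_async(stat_b, op=hvd.Sum, name='demo.stat_b')

    # block only at the point where the reduced values are actually needed
    stat_a = hvd.synchronize(handle_a)
    stat_b = hvd.synchronize(handle_b)
    return stat_a, stat_b


if __name__ == '__main__':
    a, b = overlapped_allreduce_demo()
    print(hvd.rank(), a.sum().item(), b.sum().item())

Launching both reductions before synchronizing lets the two collectives proceed concurrently instead of serializing them, which is exactly why the backward pass issues both handles first and waits afterwards.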
Example #2
def upsnet_train():

    if is_master:
        logger.info('training config:{}\n'.format(pprint.pformat(config)))
    gpus = [torch.device('cuda', int(_)) for _ in config.gpus.split(',')]
    num_replica = hvd.size() if config.train.use_horovod else len(gpus)
    num_gpus = 1 if config.train.use_horovod else len(gpus)

    # create models
    train_model = eval(config.symbol)().cuda()
        
    # create optimizer
    params_lr = train_model.get_params_lr()
    # we use custom optimizer and pass lr=1 to support different lr for different weights
    optimizer = SGD(params_lr, lr=1, momentum=config.train.momentum, weight_decay=config.train.wd)
    if config.train.use_horovod:
        optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=train_model.named_parameters())
    optimizer.zero_grad()

    # create data loader
    train_dataset = eval(config.dataset.dataset)(image_sets=config.dataset.image_set.split('+'), flip=config.train.flip, result_path=final_output_path)
    val_dataset = eval(config.dataset.dataset)(image_sets=config.dataset.test_image_set.split('+'), flip=False, result_path=final_output_path, phase='val')
    if config.train.use_horovod:
        train_sampler = distributed.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        val_sampler = distributed.DistributedSampler(val_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train.batch_size, sampler=train_sampler, num_workers=num_gpus * 4, drop_last=False, collate_fn=train_dataset.collate)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.train.batch_size, sampler=val_sampler, num_workers=num_gpus * 4, drop_last=False, collate_fn=val_dataset.collate)
    else:
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train.batch_size, shuffle=config.train.shuffle, num_workers=num_gpus * 4, drop_last=False, collate_fn=train_dataset.collate)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.train.batch_size, shuffle=False, num_workers=num_gpus * 4, drop_last=False, collate_fn=val_dataset.collate)

    # preparing
    curr_iter = config.train.begin_iteration
    batch_end_callback = [Speedometer(num_replica * config.train.batch_size, config.train.display_iter)]
    metrics = []
    metrics_name = []
    if config.network.has_rpn:
        metrics.extend([AvgMetric(name='rpn_cls_loss'), AvgMetric(name='rpn_bbox_loss'),])
        metrics_name.extend(['rpn_cls_loss', 'rpn_bbox_loss'])
    if config.network.has_rcnn:
        metrics.extend([AvgMetric(name='rcnn_accuracy'), AvgMetric(name='cls_loss'), AvgMetric(name='bbox_loss'),])
        metrics_name.extend(['rcnn_accuracy', 'cls_loss', 'bbox_loss'])
    if config.network.has_mask_head:
        metrics.extend([AvgMetric(name='mask_loss'), ])
        metrics_name.extend(['mask_loss'])
    if config.network.has_fcn_head:
        metrics.extend([AvgMetric(name='fcn_loss'), ])
        metrics_name.extend(['fcn_loss'])
        if config.train.fcn_with_roi_loss:
            metrics.extend([AvgMetric(name='fcn_roi_loss'), ])
            metrics_name.extend(['fcn_roi_loss'])
    if config.network.has_panoptic_head:
        metrics.extend([AvgMetric(name='panoptic_accuracy'), AvgMetric(name='panoptic_loss'), ])
        metrics_name.extend(['panoptic_accuracy', 'panoptic_loss'])

    if config.train.resume:
        train_model.load_state_dict(torch.load(os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth')), resume=True)
        optimizer.load_state_dict(torch.load(os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth')))
        if config.train.use_horovod:
            hvd.broadcast_parameters(train_model.state_dict(), root_rank=0)
    else:
        if is_master:
            train_model.load_state_dict(torch.load(config.network.pretrained))

        if config.train.use_horovod:
            hvd.broadcast_parameters(train_model.state_dict(), root_rank=0)

    if not config.train.use_horovod:
        train_model = DataParallel(train_model, device_ids=[int(_) for _ in config.gpus.split(',')]).to(gpus[0])

    if is_master:
        batch_end_callback[0](0, 0)

    train_model.eval()

    # start training
    while curr_iter < config.train.max_iteration:
        if config.train.use_horovod:
            train_sampler.set_epoch(curr_iter)

            if config.network.use_syncbn:
                train_model.train()
                if config.network.backbone_freeze_at > 0:
                    train_model.freeze_backbone(config.network.backbone_freeze_at)
                if config.network.backbone_fix_bn:
                    train_model.resnet_backbone.eval()


            for inner_iter, batch in enumerate(train_loader):
                data, label, _ = batch
                for k, v in data.items():
                    data[k] = v if not torch.is_tensor(v) else v.cuda()
                for k, v in label.items():
                    label[k] = v if not torch.is_tensor(v) else v.cuda()

                lr = adjust_learning_rate(optimizer, curr_iter, config)
                optimizer.zero_grad()
                output = train_model(data, label)
                loss = 0
                if config.network.has_rpn:
                    loss = loss + output['rpn_cls_loss'].mean() + output['rpn_bbox_loss'].mean()
                if config.network.has_rcnn:
                    loss = loss + output['cls_loss'].mean() + output['bbox_loss'].mean() * config.train.bbox_loss_weight
                if config.network.has_mask_head:
                    loss = loss + output['mask_loss'].mean()
                if config.network.has_fcn_head:
                    loss = loss + output['fcn_loss'].mean() * config.train.fcn_loss_weight
                    if config.train.fcn_with_roi_loss:
                        loss = loss + output['fcn_roi_loss'].mean() * config.train.fcn_loss_weight * 0.2
                if config.network.has_panoptic_head:
                    loss = loss + output['panoptic_loss'].mean() * config.train.panoptic_loss_weight
                loss.backward()
                optimizer.step(lr)

                losses = []
                losses.append(allreduce_async(loss, name='train_total_loss'))
                for l in metrics_name:
                    losses.append(allreduce_async(output[l].mean(), name=l))

                loss = hvd.synchronize(losses[0]).item()
                if is_master:
                    writer.add_scalar('train_total_loss', loss, curr_iter)
                for i, (metric, l) in enumerate(zip(metrics, metrics_name)):
                    loss = hvd.synchronize(losses[i + 1]).item()
                    if is_master:
                        writer.add_scalar('train_' + l, loss, curr_iter)
                        metric.update(_, _, loss)
                curr_iter += 1


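                # at the scheduled decay iterations the accumulated SGD momentum buffers
                # are divided by 10 as well, presumably to match the learning-rate step
                # applied by adjust_learning_rate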
                if curr_iter in config.train.decay_iteration:
                    if is_master:
                        logger.info('decay momentum buffer')
                    for k in optimizer.state_dict()['state'].keys():
                        if 'momentum_buffer' in optimizer.state_dict()['state'][k]:
                            optimizer.state_dict()['state'][k]['momentum_buffer'].div_(10)

                if is_master:
                    if curr_iter % config.train.display_iter == 0:
                        for callback in batch_end_callback:
                            callback(curr_iter, metrics)

                    if curr_iter % config.train.snapshot_step == 0:
                        logger.info('taking snapshot ...')
                        torch.save(train_model.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.pth'))
                        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.state.pth'))
        else:
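            # non-Horovod path: fetch one mini-batch per GPU and feed the per-GPU
            # batches to the DataParallel-wrapped model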
            inner_iter = 0
            train_iterator = iter(train_loader)
            while inner_iter + num_gpus <= len(train_loader):
                batch = []
                for gpu_id in gpus:
                    data, label, _ = next(train_iterator)
                    for k, v in data.items():
                        data[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                    for k, v in label.items():
                        label[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                    batch.append((data, label))
                    inner_iter += 1
                lr = adjust_learning_rate(optimizer, curr_iter, config)
                optimizer.zero_grad()
                output = train_model(*batch)

                loss = 0
                if config.network.has_rpn:
                    loss = loss + output['rpn_cls_loss'].mean() + output['rpn_bbox_loss'].mean()
                if config.network.has_rcnn:
                    loss = loss + output['cls_loss'].mean() + output['bbox_loss'].mean()
                if config.network.has_mask_head:
                    loss = loss + output['mask_loss'].mean()
                if config.network.has_fcn_head:
                    loss = loss + output['fcn_loss'].mean() * config.train.fcn_loss_weight
                    if config.train.fcn_with_roi_loss:
                        loss = loss + output['fcn_roi_loss'].mean() * config.train.fcn_loss_weight * 0.2
                if config.network.has_panoptic_head:
                    loss = loss + output['panoptic_loss'].mean() * config.train.panoptic_loss_weight
                loss.backward()
                optimizer.step(lr)
                
                losses = []
                losses.append(loss.item())
                for l in metrics_name:
                    losses.append(output[l].mean().item())

                loss = losses[0]
                if is_master:
                    writer.add_scalar('train_total_loss', loss, curr_iter)
                for i, (metric, l) in enumerate(zip(metrics, metrics_name)):
                    loss = losses[i + 1]
                    if is_master:
                        writer.add_scalar('train_' + l, loss, curr_iter)
                        metric.update(_, _, loss)
                curr_iter += 1

                if curr_iter in config.train.decay_iteration:
                    if is_master:
                        logger.info('decay momentum buffer')
                    for k in optimizer.state_dict()['state'].keys():
                        if 'momentum_buffer' in optimizer.state_dict()['state'][k]:
                            optimizer.state_dict()['state'][k]['momentum_buffer'].div_(10)

                if is_master:
                    if curr_iter % config.train.display_iter == 0:
                        for callback in batch_end_callback:
                            callback(curr_iter, metrics)


                    if curr_iter % config.train.snapshot_step == 0:
                        logger.info('taking snapshot ...')
                        torch.save(train_model.module.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.pth'))
                        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.state.pth'))

            while True:
                try:
                    next(train_iterator)
                except StopIteration:
                    break

        for metric in metrics:
            metric.reset()

        if config.train.eval_data:
            train_model.eval()

            if config.train.use_horovod:
                for inner_iter, batch in enumerate(val_loader):
                    data, label, _ = batch
                    for k, v in data.items():
                        data[k] = v if not torch.is_tensor(v) else v.cuda(non_blocking=True)
                    for k, v in label.items():
                        label[k] = v if not torch.is_tensor(v) else v.cuda(non_blocking=True)

                    with torch.no_grad():
                        output = train_model(data, label)

                    for metric, l in zip(metrics, metrics_name):
                        loss = hvd.allreduce(output[l].mean()).item()
                        if is_master:
                            metric.update(_, _, loss)

            else:
                inner_iter = 0
                val_iterator = iter(val_loader)
                while inner_iter + len(gpus) <= len(val_loader):
                    batch = []
                    for gpu_id in gpus:
                        data, label, _ = next(val_iterator)
                        for k, v in data.items():
                            data[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                        for k, v in label.items():
                            label[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                        batch.append((data, label))
                        inner_iter += 1

                    with torch.no_grad():
                        output = train_model(*batch)

                    losses = []
                    for l in metrics_name:
                        losses.append(output[l].mean().item())

                    for metric, loss in zip(metrics, losses):
                        if is_master:
                            metric.update(_, _, loss)

                while True:
                    try:
                        next(val_iterator)
                    except StopIteration:
                        break

            s = 'Batch [%d]\t Epoch[%d]\t' % (curr_iter, curr_iter // len(train_loader))

            for metric in metrics:
                m, v = metric.get()
                s += 'Val-%s=%f,\t' % (m, v)
                if is_master:
                    writer.add_scalar('val_' + m, v, curr_iter)
                    metric.reset()
            if is_master:
                logger.info(s)

    if is_master and config.train.use_horovod:
        logger.info('taking snapshot ...')
        torch.save(train_model.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
    elif not config.train.use_horovod:
        logger.info('taking snapshot ...')
        torch.save(train_model.module.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
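
The comment about passing lr=1 to a custom SGD hints that upsnet_train applies per-parameter learning rates through the parameter groups rather than the optimizer's global rate. The sketch below illustrates that idea with plain torch.optim.SGD instead of the custom optimizer; the group split, lr_mult values, base_lr, and decay schedule are assumptions, and this adjust_learning_rate is only a stand-in for the one used above.

import torch
import torch.nn as nn
import torch.optim as optim

# toy stand-in for the real network
model = nn.Sequential(nn.Linear(16, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))

# split parameters: 1-D tensors (biases, BN affine params) get no weight decay
decay, no_decay = [], []
for p in model.parameters():
    (no_decay if p.ndim == 1 else decay).append(p)

base_lr = 0.02  # assumed base rate
param_groups = [
    {'params': decay, 'lr': base_lr, 'weight_decay': 1e-4, 'lr_mult': 1.0},
    {'params': no_decay, 'lr': base_lr * 2.0, 'weight_decay': 0.0, 'lr_mult': 2.0},
]
optimizer = optim.SGD(param_groups, lr=base_lr, momentum=0.9)


def adjust_learning_rate(optimizer, curr_iter, decay_iters=(60000, 80000), gamma=0.1):
    # step schedule: shrink every group's rate by gamma at each decay iteration,
    # preserving the per-group multiplier
    scale = gamma ** sum(curr_iter >= d for d in decay_iters)
    for group in optimizer.param_groups:
        group['lr'] = base_lr * group.get('lr_mult', 1.0) * scale
    return base_lr * scale

Called once per iteration, as in the loop above, this keeps the schedule logic outside the optimizer; UPSNet's custom SGD with lr=1 and optimizer.step(lr) appears to achieve a similar effect by folding the rate into the step call.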
Example #3
    def backward(self, grad_output):
        grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_all = self.saved_tensors
        need_input_grad, need_weight_grad, need_bias_grad = self.needs_input_grad[
            0:3]

        # calculate local stats as well as grad_weight / grad_bias
        sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output, saved_input, mean, invstd, weight, need_input_grad,
            need_weight_grad, need_bias_grad)

        if need_input_grad:
            # synchronizing stats used to calculate input gradient.
            sum_dy_handle = allreduce_async(sum_dy,
                                            op=Sum,
                                            name='sync_batch_norm.sum_dy')
            sum_dy_xmu_handle = allreduce_async(
                sum_dy_xmu, op=Sum, name='sync_batch_norm.sum_dy_xmu')

            # wait on the async communication to finish
            sum_dy = synchronize(sum_dy_handle)
            sum_dy_xmu = synchronize(sum_dy_xmu_handle)

            if _SYNC_BN_V4:
                # from 1.9.0 on we need a count tensor on all devices;
                # count_all is calculated as the total count across all ranks in the forward function
                count_all = count_all.to(dtype=torch.int,
                                         device=grad_output.device)
            elif _SYNC_BN_V2 or _SYNC_BN_V3:
                # before 1.9.0 we need the count as an integer to compute mean values
                count = count_all.sum()
            else:
                # before 1.5.0, sum_dy was the sum of means from every worker, so we
                # just need to divide it by the number of workers
                count = size()

            # backward pass for gradient calculation
            # we are calling into a non-public, undocumented function that broke when moving to 1.9.0
            # https://github.com/pytorch/pytorch/issues/57900
            if _SYNC_BN_V4:
                # from 1.9.0 on, sums and count parameters expected
                grad_input = torch.batch_norm_backward_elemt(
                    grad_output, saved_input, mean, invstd, weight, sum_dy,
                    sum_dy_xmu, count_all)
            else:
                # before 1.9.0, mean parameters expected, not sums and count
                grad_input = torch.batch_norm_backward_elemt(
                    grad_output, saved_input, mean, invstd, weight,
                    sum_dy / count, sum_dy_xmu / count)
        else:
            grad_input = None

        # synchronizing grad_weight / grad_bias is not needed, as distributed
        # training handles the allreduce.
        if weight is None or not need_weight_grad:
            grad_weight = None

        if weight is None or not need_bias_grad:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
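
The _SYNC_BN_V2 / _SYNC_BN_V3 / _SYNC_BN_V4 flags that gate the two code paths above are, judging by the comments, derived from the installed torch.__version__. A rough sketch of how such flags might be defined follows; the exact 1.6.0 boundary between V2 and V3 is an assumption.

import re
import torch

# parse the leading numeric release part of torch.__version__,
# e.g. '1.10.2+cu113' -> (1, 10, 2)
_TORCH = tuple(int(x) for x in re.match(r'(\d+)\.(\d+)\.(\d+)', torch.__version__).groups())

# 1.5.0 switched the reduced statistics from per-worker means to sums with counts,
# and 1.9.0 changed batch_norm_backward_elemt to expect sums plus a count tensor
# (both per the comments above); the 1.6.0 split between V2 and V3 is assumed
_SYNC_BN_V2 = (1, 5, 0) <= _TORCH < (1, 6, 0)
_SYNC_BN_V3 = _TORCH >= (1, 6, 0)
_SYNC_BN_V4 = _TORCH >= (1, 9, 0)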