コード例 #1
0
def _load_pretrained_state(path):
    """Return the checkpoint object stored at *path*.

    *path* is treated as a URL when it starts with ``http`` and as a local
    file otherwise.  Tensors are mapped to the CPU so that loading also works
    when training was started with ``--cpu``; ``load_state_dict`` copies the
    values onto whatever device the model lives on.
    """
    if path[:4] == 'http':
        return model_zoo.load_url(path, map_location='cpu')
    return torch.load(path, map_location='cpu')


def main():
    """Create the model and start the training.

    Parses the command line into the module-level ``args``, optionally
    initialises distributed training, builds the segmentation network and the
    discriminators, then runs the adversarial training loop: a supervised
    segmentation step on source batches, an adversarial step on target
    batches, and a discriminator step on both.  Rank 0 logs to stdout (and
    TensorBoard when enabled) and writes snapshots to ``args.snapshot_dir``.

    Note: ``model_D2`` / ``optimizer_D2`` and the ``*_value2`` counters are
    created and reported but never trained or updated in this function.

    Raises:
        ValueError: if ``args.model`` is not one of the supported networks.
    """
    global args
    args = get_arguments()

    rank, world_size = 0, 1
    if args.dist:
        init_dist(args.launcher, backend=args.backend)
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    device = torch.device("cpu" if args.cpu else "cuda")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'Deeplab':
        model = DeeplabMulti(num_classes=args.num_classes)
        saved_state_dict = _load_pretrained_state(args.restore_from)

        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            key_parts = key.split('.')
            # Drop the leading name component; skip 'layer5' entries when
            # training with 19 classes.
            if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        saved_state_dict = _load_pretrained_state(args.restore_from)
        model.load_state_dict(saved_state_dict, strict=False)
    elif args.model == 'DeeplabVGGBN':
        deeplab_vggbn.BatchNorm = SyncBatchNorm2d
        model = deeplab_vggbn.DeeplabVGGBN(num_classes=args.num_classes)
        saved_state_dict = _load_pretrained_state(args.restore_from)
        # Apply the weights for URL checkpoints as well as local files.
        model.load_state_dict(saved_state_dict, strict=False)
        del saved_state_dict
    else:
        raise ValueError('unsupported model: {}'.format(args.model))

    model.train()
    model.to(device)
    if args.dist:
        broadcast_params(model)

    if rank == 0:
        print(model)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    model_D1.train()
    if args.dist:
        broadcast_params(model_D1)
    if args.restore_D is not None:
        D_dict = torch.load(args.restore_D, map_location='cpu')
        model_D1.load_state_dict(D_dict, strict=False)
        del D_dict

    model_D2.train()
    if args.dist:
        broadcast_params(model_D2)

    # exist_ok avoids a race when several ranks create the directory at once.
    os.makedirs(args.snapshot_dir, exist_ok=True)

    max_iters = args.num_steps * args.iter_size * args.batch_size

    train_data = GTA5BDDDataSet(args.data_dir,
                                args.data_list,
                                max_iters=max_iters,
                                crop_size=input_size,
                                scale=args.random_scale,
                                mirror=args.random_mirror,
                                mean=IMG_MEAN)
    train_sampler = DistributedSampler(train_data) if args.dist else None
    trainloader = data.DataLoader(train_data,
                                  batch_size=args.batch_size,
                                  shuffle=train_sampler is None,
                                  num_workers=args.num_workers,
                                  pin_memory=False,
                                  sampler=train_sampler)
    trainloader_iter = enumerate(cycle(trainloader))

    target_data = BDDDataSet(args.data_dir_target,
                             args.data_list_target,
                             max_iters=max_iters,
                             crop_size=input_size_target,
                             scale=False,
                             mirror=args.random_mirror,
                             mean=IMG_MEAN,
                             set=args.set)
    target_sampler = DistributedSampler(target_data) if args.dist else None
    targetloader = data.DataLoader(target_data,
                                   batch_size=args.batch_size,
                                   shuffle=target_sampler is None,
                                   num_workers=args.num_workers,
                                   pin_memory=False,
                                   sampler=target_sampler)
    targetloader_iter = enumerate(cycle(targetloader))

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard and rank == 0:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)

    torch.cuda.empty_cache()
    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for _ in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            _, batch = next(trainloader_iter)

            images, labels, size, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)
            interp = nn.Upsample(size=(size[1], size[0]),
                                 mode='bilinear',
                                 align_corners=True)

            pred1 = interp(model(images))

            loss_seg1 = seg_loss(pred1, labels)

            # proper normalization
            loss = loss_seg1 / args.iter_size / world_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size

            # train with target
            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.to(device)

            pred_target1 = interp_target(model(images))

            # Softmax over the class channel (dim=1 of an N x C x H x W map);
            # this is what the implicit-dim call selected for 4-D input.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            loss_adv_target1 = bce_loss(
                D_out1, torch.full_like(D_out1, source_label))

            loss = args.lambda_adv_target1 * loss_adv_target1
            loss = loss / args.iter_size / world_size

            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            D_out1 = model_D1(F.softmax(pred1, dim=1))
            loss_D1 = bce_loss(
                D_out1, torch.full_like(D_out1, source_label))
            loss_D1 = loss_D1 / args.iter_size / 2 / world_size
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

            # train with target
            pred_target1 = pred_target1.detach()
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            loss_D1 = bce_loss(
                D_out1, torch.full_like(D_out1, target_label))
            loss_D1 = loss_D1 / args.iter_size / 2 / world_size
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

        # Synchronise gradients once, after all sub-iterations have been
        # accumulated; reducing inside the accumulation loop would reduce the
        # already-synchronised part again when iter_size > 1.
        if args.dist:
            average_gradients(model)
            average_gradients(model_D1)
            average_gradients(model_D2)

        optimizer.step()
        optimizer_D1.step()

        # Every rank must agree on when to stop, otherwise the ranks that
        # keep looping would wait forever in the next collective call.
        stop_training = i_iter >= args.num_steps_stop - 1

        if rank == 0:
            if args.tensorboard:
                scalar_info = {
                    'loss_seg1': loss_seg_value1,
                    'loss_seg2': loss_seg_value2,
                    'loss_adv_target1': loss_adv_target_value1,
                    'loss_adv_target2': loss_adv_target_value2,
                    'loss_D1': loss_D_value1 * world_size,
                    'loss_D2': loss_D_value2 * world_size,
                }

                if i_iter % 10 == 0:
                    for key, val in scalar_info.items():
                        writer.add_scalar(key, val, i_iter)

            print('exp = {}'.format(args.snapshot_dir))
            print(
                'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
                .format(i_iter, args.num_steps, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1, loss_D_value2))

            if stop_training:
                print('save model ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(args.num_steps_stop) + '.pth'))
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            elif i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(i_iter) + '.pth'))
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(i_iter) + '_D1.pth'))

        if stop_training:
            break

    print(args.snapshot_dir)
    if args.tensorboard and rank == 0:
        writer.close()
コード例 #2
0
def train(train_loader, model, lr_scheduler, epoch, cfg, warmup=False):
    """Run one training epoch of the detection model.

    The model is moved to CUDA and put in train mode, after which every
    module whose class name contains ``BatchNorm`` is switched back to eval
    mode.  For each batch the individual losses returned in
    ``outputs['losses']`` are summed, divided by the world size, and
    back-propagated; gradients are synchronised with ``average_gradients``
    in distributed mode before ``lr_scheduler.optimizer`` steps.

    Args:
        train_loader: iterable of batches; items 0, 1, 2, 4 and 5 of each
            batch are used as image, image info, boxes, keypoints and masks.
        model: network called with a dict input; must return a dict with
            ``'losses'`` and ``'accuracy'`` entries.
        lr_scheduler: scheduler exposing ``step()``, ``get_lr()`` and the
            wrapped ``optimizer``.
        epoch: current epoch number (used for the sampler seed and logging).
        cfg: configuration object forwarded to the model untouched.
        warmup: when true, the scheduler is stepped before every iteration.
    """
    logger = logging.getLogger('global')

    model.cuda()
    model.train()
    world_size = dist.get_world_size() if args.dist else 1

    def freeze_bn(m):
        if 'BatchNorm' in m.__class__.__name__:
            m.eval()

    model.apply(freeze_bn)
    logger.info('freeze bn')

    if args.dist:
        # update random seed
        train_loader.sampler.set_epoch(epoch)

    num_batches = len(train_loader)
    t0 = time.time()
    for step, batch in enumerate(train_loader):
        if warmup:
            # update lr for each iteration
            lr_scheduler.step()
        x = {
            'cfg': cfg,
            'image': batch[0].cuda(),
            'image_info': batch[1][:, :3],
            'ground_truth_bboxes': batch[2],
            'ignore_regions': None,  # batch[3],
            'ground_truth_keypoints': batch[4],
            'ground_truth_masks': batch[5]
        }

        outputs = model(x)

        (rpn_cls_loss, rpn_loc_loss, rcnn_cls_loss, rcnn_loc_loss,
         keypoint_loss) = outputs['losses']
        # gradient is averaged by normalizing the loss with world_size;
        # a missing (None) loss term is simply left out of the total.
        loss = sum(term for term in outputs['losses']
                   if term is not None) / world_size

        lr_scheduler.optimizer.zero_grad()
        loss.backward()
        if args.dist:
            average_gradients(model)
        lr_scheduler.optimizer.step()

        rpn_accuracy = outputs['accuracy'][0][0] / 100.
        rcnn_accuracy = outputs['accuracy'][1][0] / 100.
        # .item() is the supported way to read a 0-dim tensor as a float.
        keypoint_value = (keypoint_loss.item()
                          if keypoint_loss is not None else float('nan'))

        t2 = time.time()
        lr = lr_scheduler.get_lr()[0]
        logger.info(
            'Epoch: [%d][%d/%d] LR:%f Time: %.3f Loss: %.5f (rpn_cls: %.5f rpn_loc: %.5f rpn_acc: %.5f'
            ' rcnn_cls: %.5f, rcnn_loc: %.5f rcnn_acc:%.5f kpt:%.5f)',
            epoch, step, num_batches, lr, t2 - t0,
            loss.item() * world_size, rpn_cls_loss.item(),
            rpn_loc_loss.item(), rpn_accuracy, rcnn_cls_loss.item(),
            rcnn_loc_loss.item(), rcnn_accuracy, keypoint_value)
        print_speed((epoch - 1) * num_batches + step + 1, t2 - t0,
                    args.epochs * num_batches)
        t0 = t2
コード例 #3
0
def train(train_loader, model, criterion, optimizer, lr_scheduler, epoch):
    """Train ``model`` for one epoch.

    Per batch: update the learning rate through ``lr_scheduler.update``,
    compute the loss scaled by the module-level ``world_size``, record the
    (optionally all-reduced) loss and top-1/top-5 precision, then
    back-propagate, synchronise gradients when ``args.distribute`` is set,
    and step the optimizer.  Rank 0 prints a progress line every
    ``args.print_freq`` batches.

    Returns:
        Tuple ``(top1.avg, losses.avg)`` of the running averages.
    """
    batch_time, data_time, losses, top1, top5 = (
        utils.AverageMeter() for _ in range(5))

    # switch to train mode
    model.train()

    tick = time.time()
    for step, (images, target) in enumerate(train_loader):
        # time spent waiting for the data loader
        data_time.update(time.time() - tick)
        lr = lr_scheduler.update(step, epoch)

        target = target.cuda()
        images_var = torch.autograd.Variable(images.cuda())
        target_var = torch.autograd.Variable(target)

        # forward pass
        output = model(images_var)

        # loss is pre-divided by world_size so summed gradients average out
        loss = criterion(output, target_var) / world_size
        prec1, prec5 = utils.accuracy(output.data, target, topk=(1, 5))

        reduced = [
            loss.data.clone(),
            prec1.clone() / world_size,
            prec5.clone() / world_size,
        ]

        if args.distribute:
            for stat in reduced:
                dist.all_reduce_multigpu([stat])

        batch_count = images.size(0)
        for meter, stat in zip((losses, top1, top5), reduced):
            meter.update(stat.item(), batch_count)

        # backward pass and parameter update
        loss.backward()
        if args.distribute:
            average_gradients(model)
        optimizer.step()
        optimizer.zero_grad()

        # time spent on the whole iteration
        batch_time.update(time.time() - tick)
        tick = time.time()

        should_report = step % args.print_freq == 0 and rank == 0
        if should_report:
            message = ('Ep: [{0}][{1}/{2}]  '
                       'T {batch_time.val:.2f} ({batch_time.avg:.2f})  '
                       'D {data_time.val:.2f} ({data_time.avg:.2f})  '
                       'LR {lr:.4f}  '
                       'L {loss.val:.3f} ({loss.avg:.4f})  '
                       'P1 {top1.val:.3f} ({top1.avg:.3f})  '
                       'P5 {top5.val:.3f} ({top5.avg:.3f})')
            print(message.format(epoch,
                                 step,
                                 len(train_loader),
                                 batch_time=batch_time,
                                 data_time=data_time,
                                 lr=lr,
                                 loss=losses,
                                 top1=top1,
                                 top5=top5))
    return top1.avg, losses.avg
コード例 #4
0
def train(train_loader,
          target_loader,
          val_loader,
          model,
          dec_model,
          dis_model,
          dis_model_patch,
          lr_scheduler,
          lr_scheduler_dec,
          lr_scheduler_dis,
          lr_scheduler_dis_patch,
          epoch,
          cfg,
          warmup=False):
    """Run one epoch of joint detector / decoder / discriminator training.

    Source and target batches are consumed in lockstep with ``zip``, so the
    epoch ends when the shorter of the two loaders is exhausted.  Each
    iteration performs four optimizer updates, in this order:

    1. ``dis_model`` (image discriminator) on reconstructed vs. cropped
       patches, via ``lr_scheduler_dis.optimizer``.
    2. ``dis_model_patch`` (feature-patch discriminator), via
       ``lr_scheduler_dis_patch.optimizer``.
    3. ``dec_model`` (decoder), via ``lr_scheduler_dec.optimizer``.
    4. ``model`` (detector) on its four detection losses plus a 0.1-weighted
       adversarial term, via ``lr_scheduler.optimizer``.

    All losses are divided by the world size before ``backward`` and, in
    distributed mode, gradients are synchronised with ``average_gradients``.

    Args:
        train_loader: source-domain loader; items 0, 1 and 2 of each batch
            are used as image, image info and ground-truth boxes.
        target_loader: target-domain loader; each batch is moved to CUDA and
            passed to ``model`` as its second argument.
        val_loader: not used in this function.
        model: detector; called as ``model(x, target)`` and expected to
            return a dict with ``'cluster_centers'``, ``'cluster_features'``,
            ``'losses'`` and ``'accuracy'``.
        dec_model: decoder called on the source/target cluster features.
        dis_model: discriminator called on pairs of image patches.
        dis_model_patch: discriminator called on cluster features.
        lr_scheduler, lr_scheduler_dec, lr_scheduler_dis,
        lr_scheduler_dis_patch: schedulers whose ``optimizer`` attribute
            drives the corresponding model.
        epoch: current epoch number (sampler seed and logging).
        cfg: configuration object forwarded to the model untouched.
        warmup: when true, all four schedulers are stepped every iteration.

    Returns:
        None.  Progress is reported through the ``'global'`` logger and
        ``print_speed``.
    """
    logger = logging.getLogger('global')
    model.cuda()
    model.train()
    dis_model.cuda()
    dis_model.train()
    dec_model.cuda()
    dec_model.train()
    dis_model_patch.cuda()
    dis_model_patch.train()

    if args.dist:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        world_size = 1
        rank = 0

    def freeze_bn(m):
        # Matches every module whose class name contains 'Norm', not only
        # BatchNorm layers.
        classname = m.__class__.__name__
        if classname.find('Norm') != -1:
            m.eval()

    model.apply(freeze_bn)
    # Put the first `fix_num` Conv2d modules (in model.modules() order) into
    # eval mode.
    fix_num = args.fix_num
    count = 1
    for mm in model.modules():
        if count > fix_num:
            break
        if isinstance(mm, torch.nn.Conv2d) and count <= fix_num:
            mm.eval()
            count += 1

    # dec_model.apply(freeze_bn)
    logger.info('freeze bn')

    # NOTE(review): `end` and `l1_loss` below are assigned but never read in
    # this function.
    end = time.time()

    t0 = time.time()
    l1_loss = torch.nn.L1Loss()
    if args.dist:
        # update random seed
        train_loader.sampler.set_epoch(epoch)
        target_loader.sampler.set_epoch(epoch)

    for iter, (input, target) in enumerate(zip(train_loader, target_loader)):
        # torch.cuda.empty_cache()
        if warmup:
            # update lr for each iteration
            lr_scheduler.step()
            lr_scheduler_dis.step()
            lr_scheduler_dec.step()
            lr_scheduler_dis_patch.step()
        x = {
            'cfg': cfg,
            'image': (input[0]).cuda(),
            'image_info': input[1],
            'ground_truth_bboxes': input[2],
            'ignore_regions': None,
            'cluster_num': args.cluster_num,
            'threshold': args.threshold
            # 'ignore_regions': input[3] if args.dataset == 'coco' else None
        }
        target = (target).cuda()
        outputs = model(x, target)

        centers_source, centers_target = outputs['cluster_centers']

        # Each corner entry is indexed as (x1, y1, x2, y2) below.
        corners_source = get_corner_from_center(centers_source)
        corners_target = get_corner_from_center(centers_target)

        # Crop one image patch per cluster from the source and target images.
        x_small = []
        target_small = []
        for corners_idx in range(0, len(corners_source)):
            x1 = corners_source[corners_idx][0]
            y1 = corners_source[corners_idx][1]
            x2 = corners_source[corners_idx][2]
            y2 = corners_source[corners_idx][3]
            # The messages mention 256, but the check is against
            # args.recon_size.
            assert (
                x2 -
                x1 == args.recon_size), "x size does not match 256 in source "
            assert (
                y2 -
                y1 == args.recon_size), "y size does not match 256 in source "
            x_small_tmp = x['image'][:, :, y1:y2, x1:x2]
            x_small.append(x_small_tmp)

        x_small = torch.cat(x_small, 0)

        for corners_idx in range(0, len(corners_target)):
            x1 = corners_target[corners_idx][0]
            y1 = corners_target[corners_idx][1]
            x2 = corners_target[corners_idx][2]
            y2 = corners_target[corners_idx][3]
            assert (
                x2 -
                x1 == args.recon_size), "x size does not match 256 in target "
            assert (
                y2 -
                y1 == args.recon_size), "y size does not match 256 in target "
            target_small_tmp = target[:, :, y1:y2, x1:x2]
            target_small.append(target_small_tmp)

        target_small = torch.cat(target_small, 0)  # Size(4, 3, 256, 256)

        x_source_patch, x_target_patch = outputs[
            'cluster_features']  # Size(4, 128, 4096)
        x_source_recon, x_target_recon = dec_model(
            x_source_patch, x_target_patch)  # Size(4, 3, 256, 256)

        ##########################################################################
        ######################### (1): start dis_update ##########################
        ##########################################################################

        lr_scheduler_dis.optimizer.zero_grad()
        x_source_dis, x_target_dis = dis_model(x_source_recon,
                                               x_target_recon)  # (4, 256)
        x_source_real, x_target_real = dis_model(x_small,
                                                 target_small)  # (4, 256)
        x_source_dis = torch.sigmoid(x_source_dis)  # (4, dim)
        x_target_dis = torch.sigmoid(x_target_dis)  # (4, dim)
        x_source_real = torch.sigmoid(x_source_real)  # (4, dim)
        x_target_real = torch.sigmoid(x_target_real)

        # Split along dim 0 so that each cluster is scored separately.
        x_source_dis_cluster = torch.split(x_source_dis, 1, dim=0)
        x_source_real_cluster = torch.split(x_source_real, 1, dim=0)
        score_1_cluster = generate_soft_label(1, x_source_real_cluster[0])
        score_0_cluster = generate_soft_label(0, x_source_dis_cluster[0])

        adloss_source = 0.0

        #################### (1.1): for source clusters############################
        # NOTE(review): here reconstructed source patches are paired with the
        # "1" label and real crops with "0", the opposite of the target
        # branch in (1.2) — confirm this is intentional.
        for clu_idx in range(0, len(x_source_dis_cluster)):
            adloss_source += (
                F.binary_cross_entropy(x_source_dis_cluster[clu_idx],
                                       score_1_cluster) +
                F.binary_cross_entropy(x_source_real_cluster[clu_idx],
                                       score_0_cluster))

        #################### (1.2): for target clusters############################

        x_target_patch_pro = dis_model_patch(x_target_patch)
        # Per-cluster mean of the patch discriminator output, used as a
        # weight on the target reconstruction term below.
        x_target_patch_pro_mean = torch.mean(x_target_patch_pro,
                                             1)  # Size(4,1)
        x_source_patch_pro = dis_model_patch(x_source_patch)  # (4, 512)

        x_target_dis_cluster = torch.split(x_target_dis, 1, dim=0)
        x_target_real_cluster = torch.split(x_target_real, 1, dim=0)

        adloss_target = 0.0
        for clu_idx in range(0, len(x_target_dis_cluster)):
            adloss_target += (
                x_target_patch_pro_mean[clu_idx] * F.binary_cross_entropy(
                    x_target_dis_cluster[clu_idx], score_0_cluster) +
                F.binary_cross_entropy(x_target_real_cluster[clu_idx],
                                       score_1_cluster))

        adloss = (adloss_source + adloss_target) / world_size
        # retain_graph: the graph through model / dec_model is reused by the
        # later backward passes in this iteration.
        adloss.backward(retain_graph=True)

        # Largest gradient entry over all dis_model parameters (logged only).
        max_grad3 = 0.0
        for pp in dis_model.parameters():
            tmp = torch.max(pp.grad.data)
            if max_grad3 < tmp:
                max_grad3 = tmp

        if args.dist:
            average_gradients(dis_model)

        lr_scheduler_dis.optimizer.step()

        ##########################################################################
        ####################### (2): start dis_patch_update ######################
        ##########################################################################

        lr_scheduler_dis_patch.optimizer.zero_grad()
        score_0_patch = generate_soft_label(0, x_target_patch_pro)
        score_1_patch = generate_soft_label(1, x_source_patch_pro)

        # Patch discriminator: target features -> 0, source features -> 1.
        patch_loss_target = F.binary_cross_entropy(x_target_patch_pro,
                                                   score_0_patch)
        patch_loss_source = F.binary_cross_entropy(x_source_patch_pro,
                                                   score_1_patch)

        dis_patch_loss = (patch_loss_source + patch_loss_target) / world_size
        dis_patch_loss.backward(retain_graph=True)
        if args.dist:
            average_gradients(dis_model_patch)

        lr_scheduler_dis_patch.optimizer.step()

        ##########################################################################
        ########################## (3): start decoder_update #####################
        ##########################################################################

        lr_scheduler_dec.optimizer.zero_grad()

        # Re-score the same reconstructions with the just-updated dis_model.
        # x_source_recon, x_target_recon = dec_model(x_target_patch, x_source_patch)
        x_source_dis, x_target_dis = dis_model(x_source_recon, x_target_recon)
        x_source_dis = torch.sigmoid(x_source_dis)  # (4, dim)

        x_source_real, x_target_real = dis_model(x_small,
                                                 target_small)  # (4, 256)
        x_source_real = torch.sigmoid(x_source_real)
        x_target_real = torch.sigmoid(x_target_real)
        """
        for the patch loss of Target image
        """
        x_target_patch_pro = dis_model_patch(x_target_patch)  # size(4, dim)
        gtav_dis_sigmoid_target = torch.sigmoid(x_target_dis)
        """
        obtain the weighting factor of target patches and calculate the target loss
        """
        x_target_patch_pro_mean2 = torch.mean(x_target_patch_pro,
                                              1)  # Size(4,1)
        fake_loss1_target = 0.0
        gtav_dis_sigmoid_target = torch.split(gtav_dis_sigmoid_target,
                                              1,
                                              dim=0)
        # allone_target_1 = (torch.ones(gtav_dis_sigmoid_target[0].size()).float().cuda())
        all_target_1 = generate_hard_label(1, gtav_dis_sigmoid_target[0])

        gtav_real_sigmoid_target = torch.split(x_target_real, 1, dim=0)
        all_target_0 = generate_hard_label(0, gtav_real_sigmoid_target[0])

        for clu_idx in range(0, len(gtav_dis_sigmoid_target)):
            fake_loss1_target += x_target_patch_pro_mean2[clu_idx] * (
                F.binary_cross_entropy(gtav_dis_sigmoid_target[clu_idx],
                                       all_target_1) +
                F.binary_cross_entropy(gtav_real_sigmoid_target[clu_idx],
                                       all_target_0))

        fake_loss1_source = 0.0
        x_source_fake_cluster2 = torch.split(x_source_dis, 1, dim=0)
        all_source_1 = generate_hard_label(1, x_source_fake_cluster2[0])
        x_source_real_cluster2 = torch.split(x_source_real, 1, dim=0)
        all_source_0 = generate_hard_label(0, x_source_real_cluster2[0])

        for clu_idx in range(0, len(x_source_fake_cluster2)):
            fake_loss1_source += (
                F.binary_cross_entropy(x_source_fake_cluster2[clu_idx],
                                       all_source_1) +
                F.binary_cross_entropy(x_source_real_cluster2[clu_idx],
                                       all_source_0))

        recon_loss = (fake_loss1_source + fake_loss1_target
                      ) / world_size  # no-discriminator in the Decoder

        # recon_loss = recon_loss
        recon_loss.backward(retain_graph=True)

        # Largest gradient entry over all dec_model parameters (logged only).
        max_grad2 = 0.0
        for pp in dec_model.parameters():
            tmp = torch.max(pp.grad.data)
            if max_grad2 < tmp:
                max_grad2 = tmp

        if args.dist:
            average_gradients(dec_model)
        # torch.nn.utils.clip_grad_norm(dec_model.parameters(), 10.0)
        lr_scheduler_dec.optimizer.step()

        ##########################################################################
        ########################### (4): start detection_update ##################
        ##########################################################################
        """
        target feature maps --> source reconstruction
        for cross-domain alignment
        """
        # Note the swapped argument order compared with the first dec_model
        # call: target features go in first here.
        x_source_recon, x_target_recon = dec_model(x_target_patch,
                                                   x_source_patch)
        x_source_dis, x_target_dis = dis_model(x_source_recon,
                                               x_target_recon)  # (4, dim)
        """
        weight of target patches
        """
        x_fake_dis_sigmoid = torch.sigmoid(x_target_dis)
        allone_11 = generate_hard_label(1, x_fake_dis_sigmoid)
        fake_loss_source = F.binary_cross_entropy(
            x_fake_dis_sigmoid, allone_11)  # NO discriminator in Detection
        x_fake_dis_sigmoid2 = torch.sigmoid(x_source_dis)

        x_fake_dis_sigmoid2_cluster = torch.split(x_fake_dis_sigmoid2,
                                                  1,
                                                  dim=0)
        allone_11_cluster = (torch.ones(
            x_fake_dis_sigmoid2_cluster[0].size()).float().cuda())

        # Weighted by the patch-discriminator means computed in phase (3).
        fake_loss_target = 0.0
        for clu_idx in range(0, len(x_fake_dis_sigmoid2_cluster)):
            fake_loss_target += x_target_patch_pro_mean2[clu_idx] * (
                F.binary_cross_entropy(x_fake_dis_sigmoid2_cluster[clu_idx],
                                       allone_11_cluster))

        rpn_cls_loss, rpn_loc_loss, rcnn_cls_loss, rcnn_loc_loss = outputs[
            'losses']
        # gradient is averaged by normalizing the loss with world_size
        loss = (rpn_cls_loss + rpn_loc_loss + rcnn_cls_loss + rcnn_loc_loss +
                0.1 * (fake_loss_source + fake_loss_target)) / world_size

        # Clears the detector gradients accumulated by the earlier backward
        # passes of this iteration before the detector's own backward.
        lr_scheduler.optimizer.zero_grad()
        loss.backward()

        # Largest gradient entry over all detector parameters (logged only).
        max_grad1 = 0.0
        for pp in model.parameters():
            tmp = torch.max(pp.grad.data)
            if max_grad1 < tmp:
                max_grad1 = tmp

        if args.dist:
            average_gradients(model)
        # torch.nn.utils.clip_grad_norm(model.parameters(), 1.0)
        lr_scheduler.optimizer.step()

        ##########################################################################
        ################################ Output information ######################
        ##########################################################################

        rpn_accuracy = outputs['accuracy'][0][0] / 100.
        rcnn_accuracy = outputs['accuracy'][1][0] / 100.

        t2 = time.time()
        lr = lr_scheduler.get_lr()[0]
        # The logged total loss is multiplied back by world_size to undo the
        # normalisation applied before backward.
        logger.info(
            'Epoch: [%d][%d/%d] LR:%f Time: %.3f Loss: %.5f (rpn_cls: %.5f rpn_loc: %.5f rpn_acc: %.5f'
            ' rcnn_cls: %.5f, rcnn_loc: %.5f rcnn_acc:%.5f fake_loss: %.5f dec_loss: %.5f  dis_loss: %.5f fake_loss1: %.5f)'
            % (epoch, iter, len(train_loader), lr, t2 - t0,
               loss.item() * world_size, rpn_cls_loss.item(),
               rpn_loc_loss.item(), rpn_accuracy, rcnn_cls_loss.item(),
               rcnn_loc_loss.item(), rcnn_accuracy, fake_loss_target.item(),
               recon_loss.item(), adloss.item(), fake_loss1_source.item()))
        print_speed((epoch - 1) * len(train_loader) + iter + 1, t2 - t0,
                    args.epochs * len(train_loader))
        t0 = t2
        logger.info("Max Grad, Det: %5f, Dec: %5f, Dis: %5f" %
                    (max_grad1, max_grad2, max_grad3))