Example #1
def validate(val_loader, model, model_ppm, criterion):
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    if rank == 0:
        logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
    batch_time = AverageMeter()
    data_time = AverageMeter()
    r_loss_meter = AverageMeter()
    t_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    


    model.eval()
    model_ppm.eval()
    
    end = time.time()
    
    for i, (input1, input2, translation, quaternions) in enumerate(val_loader):
        data_time.update(time.time() - end)
        input1 = input1.cuda()
        input_var1 = torch.autograd.Variable(input1)
        input2 = input2.cuda()
        input_var2 = torch.autograd.Variable(input2)
        output1 = model(input_var1)
        output2 = model(input_var2)

        trans, quat = model_ppm(torch.cat([output1, output2], 1))
        
        translation = translation.float().cuda(non_blocking=True)
        translation_var = torch.autograd.Variable(translation)
        quaternions = quaternions.float().cuda(non_blocking=True)
        quaternions_var = torch.autograd.Variable(quaternions)

        t_loss = criterion(trans, translation_var) / world_size
        r_loss = criterion(quat, quaternions_var) / world_size
        loss = t_loss + r_loss
        reduced_loss = loss.data.clone()
        reduced_t_loss = t_loss.data.clone()
        reduced_r_loss = r_loss.data.clone()
        dist.all_reduce(reduced_loss)
        dist.all_reduce(reduced_t_loss)
        dist.all_reduce(reduced_r_loss)
        t_loss_meter.update(reduced_t_loss[0], input1.size(0))
        r_loss_meter.update(reduced_r_loss[0], input1.size(0))
        loss_meter.update(reduced_loss[0], input1.size(0))
        
        batch_time.update(time.time() - end)
        end = time.time()


        if (i + 1) % 10 == 0 and rank == 0:
            logger.info('Test: [{}/{}] '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'tLoss {t_loss_meter.val:.4f} ({t_loss_meter.avg:.4f}) '
                        'rLoss {r_loss_meter.val:.4f} ({r_loss_meter.avg:.4f}) '.format(i + 1, len(val_loader),
                                                          data_time=data_time,
                                                          batch_time=batch_time,
                                                          t_loss_meter=t_loss_meter,
                                                          r_loss_meter=r_loss_meter))
    if rank == 0:
        logger.info('<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<')
    return t_loss_meter.avg, r_loss_meter.avg
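
Note: all of these examples rely on an AverageMeter helper that is not shown on this page. A minimal sketch of such a helper (assumed here, not taken from any of the projects above) tracks the last value, running sum, count, and average:

class AverageMeter(object):
    """Tracks the most recent value, running sum, count, and average."""
    def __init__(self):
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        # record the latest value and update the running average
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count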
Example #2
def validate(val_loader, model, criterion, classes, zoom_factor):
    logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.eval()
    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        data_time.update(time.time() - end)
        target = target.cuda(non_blocking=True)
        output = model(input)
        if zoom_factor != 8:
            output = F.upsample(output, size=target.size()[1:], mode='bilinear', align_corners=True)
        loss = criterion(output, target)

        output = output.data.max(1)[1].cpu().numpy()
        target = target.cpu().numpy()
        intersection, union, target = intersectionAndUnion(output, target, args.classes, args.ignore_label)
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        loss_meter.update(loss.item(), input.size(0))
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % 10 == 0:
            logger.info('Test: [{}/{}] '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f}) '
                        'Accuracy {accuracy:.4f}.'.format(i + 1, len(val_loader),
                                                          data_time=data_time,
                                                          batch_time=batch_time,
                                                          loss_meter=loss_meter,
                                                          accuracy=accuracy))

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    logger.info('Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(mIoU, mAcc, allAcc))

    for i in range(classes):
        logger.info('Class_{} Result: iou/accuracy {:.4f}/{:.4f}.'.format(i, iou_class[i], accuracy_class[i]))
    logger.info('<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<')
    return loss_meter.avg, mIoU, mAcc, allAcc
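
Note: intersectionAndUnion, used above and in several later examples, is an external helper. A plausible NumPy sketch (assumed, not the project's own code) builds per-class histograms over hard predictions while respecting an ignore label:

import numpy as np

def intersectionAndUnion(output, target, K, ignore_index=255):
    # output, target: integer label maps of identical shape; K: number of classes
    output = output.reshape(-1).copy()
    target = target.reshape(-1)
    output[target == ignore_index] = ignore_index  # ignored pixels never count
    intersection = output[output == target]
    area_intersection, _ = np.histogram(intersection, bins=np.arange(K + 1))
    area_output, _ = np.histogram(output, bins=np.arange(K + 1))
    area_target, _ = np.histogram(target, bins=np.arange(K + 1))
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target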
Example #3
def train(train_loader, model, model_ppm, criterion, optimizer, epoch, zoom_factor, batch_size, aux_weight):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    r_loss_meter = AverageMeter()
    t_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    


    model.train()
    model_ppm.train()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    #print(rank)
    end = time.time()
    for i, (input1, input2, translation, quaternions) in enumerate(train_loader):
        # to avoid bn problem in ppm module with bin size 1x1, sometimes n may get 1 on one gpu during the last batch, so just discard
        # if input.shape[0] < batch_size:
        #     continue
        data_time.update(time.time() - end)
        current_iter = (epoch - 1) * len(train_loader) + i + 1
        max_iter = args.epochs * len(train_loader)
        index_split = 4
        poly_learning_rate(optimizer, args.base_lr, current_iter, max_iter, power=args.power, index_split=index_split)

        input1 = input1.cuda()
        input_var1 = torch.autograd.Variable(input1)
        input2 = input2.cuda()
        input_var2 = torch.autograd.Variable(input2)
        x1_ICR, x1_PFR, x1_PRP = model(input_var1)
        x2_ICR, x2_PFR, x2_PRP = model(input_var2)
 
        x1_ICR = (x1_ICR + x1_PFR + x1_PRP) / 3
        x2_ICR = (x2_ICR + x2_PFR + x2_PRP) / 3
        trans, quat = model_ppm(torch.cat([x1_ICR, x2_ICR], 1))

        
        translation = translation.float().cuda(non_blocking=True)
        translation_var = torch.autograd.Variable(translation)
        quaternions = quaternions.float().cuda(non_blocking=True)
        quaternions_var = torch.autograd.Variable(quaternions)
        
        t_loss = criterion(trans, translation_var) / world_size
        r_loss = criterion(quat, quaternions_var) / world_size
        loss = r_loss + t_loss

        optimizer.zero_grad()
        loss.backward()
        average_gradients(model)
        optimizer.step()

        reduced_loss = loss.data.clone()
        reduced_t_loss = t_loss.data.clone()
        reduced_r_loss = r_loss.data.clone()
        dist.all_reduce(reduced_loss)
        dist.all_reduce(reduced_t_loss)
        dist.all_reduce(reduced_r_loss)

        r_loss_meter.update(reduced_r_loss[0], input1.size(0))
        t_loss_meter.update(reduced_t_loss[0], input1.size(0))
        loss_meter.update(reduced_loss[0], input1.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        # calculate remain time
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        if rank == 0:
            if (i + 1) % args.print_freq == 0:
                logger.info('Epoch: [{}/{}][{}/{}] '
                            'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                            'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                            'Remain {remain_time} '
                            'rLoss {r_loss_meter.val:.4f} '
                            'tLoss {t_loss_meter.val:.4f} '
                            'Loss {loss_meter.val:.4f} '.format(epoch, args.epochs, i + 1, len(train_loader),
                                                              batch_time=batch_time,
                                                              data_time=data_time,
                                                              remain_time=remain_time,
                                                              t_loss_meter=t_loss_meter,
                                                              r_loss_meter=r_loss_meter,
                                                              loss_meter=loss_meter))
            writer.add_scalar('loss_train_batch_r', r_loss_meter.val, current_iter)
            writer.add_scalar('loss_train_batch_t', t_loss_meter.val, current_iter)


    if rank == 0:
        logger.info('Train result at epoch [{}/{}].'.format(epoch, args.epochs))
    return t_loss_meter.avg, r_loss_meter.avg
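
Note: poly_learning_rate is called in every training loop here but never defined. A minimal sketch of a polynomial decay schedule, assuming the first index_split parameter groups (the backbone) keep the base LR and the remaining groups use 10x (this split convention is an assumption):

def poly_learning_rate(optimizer, base_lr, curr_iter, max_iter, power=0.9, index_split=-1):
    # polynomial decay: lr = base_lr * (1 - iter / max_iter) ** power
    lr = base_lr * (1 - float(curr_iter) / max_iter) ** power
    for idx, param_group in enumerate(optimizer.param_groups):
        if index_split >= 0 and idx >= index_split:
            param_group['lr'] = lr * 10  # assumed 10x LR for newly added layers
        else:
            param_group['lr'] = lr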
Example #4
def train(train_loader, model, criterion, optimizer, epoch, zoom_factor,
          batch_size, aux_weight):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    aux_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.train()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    end = time.time()
    for i, (_, input, atts) in enumerate(train_loader):
        # to avoid bn problem in ppm module with bin size 1x1, sometimes n may get 1 on one gpu during the last batch, so just discard
        # if input.shape[0] < batch_size:
        #     continue
        data_time.update(time.time() - end)
        current_iter = (epoch - 1) * len(train_loader) + i + 1
        max_iter = args.epochs * len(train_loader)
        index_split = 4
        poly_learning_rate(optimizer,
                           args.base_lr,
                           current_iter,
                           max_iter,
                           power=args.power,
                           index_split=index_split)

        input = input.cuda()
        input_var = torch.autograd.Variable(input)
        output = model(input_var).squeeze(3).squeeze(2)  #, aux

        #     if zoom_factor != 8:
        #         h = int((target.size()[1] - 1) / 8 * zoom_factor + 1)
        #         w = int((target.size()[2] - 1) / 8 * zoom_factor + 1)
        #         target = F.upsample(target.unsqueeze(1).float(), size=(h, w), mode='bilinear').squeeze(1).long()
        #     target = target.data.cuda(async=True)
        #     target_var = torch.autograd.Variable(target)
        atts_var = torch.autograd.Variable(atts.cuda(non_blocking=True))
        main_loss = F.multilabel_soft_margin_loss(output,
                                                  atts_var) / world_size

        #     aux_loss = criterion(aux, target_var) / world_size
        loss = main_loss  # + aux_weight * aux_loss

        optimizer.zero_grad()
        loss.backward()
        average_gradients(model)
        optimizer.step()

        reduced_loss = loss.data.clone()
        reduced_main_loss = main_loss.data.clone()
        dist.all_reduce(reduced_loss)
        dist.all_reduce(reduced_main_loss)

        main_loss_meter.update(reduced_main_loss[0], input.size(0))
        loss_meter.update(reduced_loss[0], input.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        # calculate remain time
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        if rank == 0:
            if (i + 1) % args.print_freq == 0:
                logger.info(
                    'Epoch: [{}/{}][{}/{}] '
                    'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                    'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    'Remain {remain_time} '
                    'MainLoss {main_loss_meter.val:.4f} '
                    'Loss {loss_meter.val:.4f} '.format(
                        epoch,
                        args.epochs,
                        i + 1,
                        len(train_loader),
                        batch_time=batch_time,
                        data_time=data_time,
                        remain_time=remain_time,
                        main_loss_meter=main_loss_meter,
                        loss_meter=loss_meter))
            writer.add_scalar('loss_train_batch', main_loss_meter.val,
                              current_iter)

    if rank == 0:
        logger.info('Train result at epoch [{}/{}]: '.format(
            epoch, args.epochs))
    return main_loss_meter.avg
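
Note: average_gradients is also an external helper. Because each loss above is already divided by world_size, summing gradients across ranks with all_reduce yields the average; a sketch under that assumption:

import torch.distributed as dist

def average_gradients(model):
    # sum gradients over all ranks; with losses pre-divided by world_size,
    # the summed gradient equals the averaged gradient
    for param in model.parameters():
        if param.requires_grad and param.grad is not None:
            dist.all_reduce(param.grad.data)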
Example #5
def train(train_loader, model, criterion, optimizer, epoch, zoom_factor, batch_size, aux_weight, fcw):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    aux_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.train()
    fcw.train()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # to avoid bn problem in ppm module with bin size 1x1, sometimes n may get 1 on one gpu during the last batch, so just discard
        # if input.shape[0] < batch_size:
        #     continue
        data_time.update(time.time() - end)
        current_iter = (epoch - 1) * len(train_loader) + i + 1
        max_iter = args.epochs * len(train_loader)
        if args.net_type == 0:
            index_split = 4
        elif args.net_type in [1, 2, 3]:
            index_split = 5
        poly_learning_rate(optimizer, args.base_lr, current_iter, max_iter, power=args.power, index_split=index_split)

        input = input.cuda()
        input_var = torch.autograd.Variable(input)
        fcw_input = fcw(input_var)
        output, aux = model(input_var, fcw_input)

        if zoom_factor != 8:
            h = int((target.size()[1] - 1) / 8 * zoom_factor + 1)
            w = int((target.size()[2] - 1) / 8 * zoom_factor + 1)
            # 'nearest' mode doesn't support this downsampling, while 'bilinear' mode works fine
            target = F.upsample(target.unsqueeze(1).float(), size=(h, w), mode='bilinear').squeeze(1).long()
        target = target.data.cuda(non_blocking=True)
        target_var = torch.autograd.Variable(target)
        main_loss = criterion(output, target_var) / world_size
        aux_loss = criterion(aux, target_var) / world_size
        loss = main_loss + aux_weight * aux_loss

        optimizer.zero_grad()
        loss.backward()
        average_gradients(model)
        optimizer.step()

        output = output.data.max(1)[1].cpu().numpy()
        target = target.cpu().numpy()
        intersection, union, target = intersectionAndUnion(output, target, args.classes, args.ignore_label)
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)

        reduced_loss = loss.data.clone()
        reduced_main_loss = main_loss.data.clone()
        reduced_aux_loss = aux_loss.data.clone()
        dist.all_reduce(reduced_loss)
        dist.all_reduce(reduced_main_loss)
        dist.all_reduce(reduced_aux_loss)

        main_loss_meter.update(reduced_main_loss[0], input.size(0))
        aux_loss_meter.update(reduced_aux_loss[0], input.size(0))
        loss_meter.update(reduced_loss[0], input.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        # calculate remain time
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        if rank == 0:
            if (i + 1) % args.print_freq == 0:
                logger.info('Epoch: [{}/{}][{}/{}] '
                            'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                            'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                            'Remain {remain_time} '
                            'MainLoss {main_loss_meter.val:.4f} '
                            'AuxLoss {aux_loss_meter.val:.4f} '
                            'Loss {loss_meter.val:.4f} '
                            'Accuracy {accuracy:.4f}.'.format(epoch, args.epochs, i + 1, len(train_loader),
                                                              batch_time=batch_time,
                                                              data_time=data_time,
                                                              remain_time=remain_time,
                                                              main_loss_meter=main_loss_meter,
                                                              aux_loss_meter=aux_loss_meter,
                                                              loss_meter=loss_meter,
                                                              accuracy=accuracy))
            writer.add_scalar('loss_train_batch', main_loss_meter.val, current_iter)
            writer.add_scalar('mIoU_train_batch', np.mean(intersection / (union + 1e-10)), current_iter)
            writer.add_scalar('mAcc_train_batch', np.mean(intersection / (target + 1e-10)), current_iter)
            writer.add_scalar('allAcc_train_batch', accuracy, current_iter)

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    if rank == 0:
        logger.info('Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(epoch, args.epochs, mIoU, mAcc, allAcc))
    return main_loss_meter.avg, mIoU, mAcc, allAcc
Example #6
def validate_multi(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()  # training loss
    prec = AverageMeter()  # precision
    rec = AverageMeter()  # recall

    model.eval()  # disable BatchNorm updates and Dropout

    end = time.time()
    tp, fp, fn, tn, count = 0, 0, 0, 0, 0
    tp_size, fn_size = 0, 0
    for i, (input2, target) in enumerate(val_loader):
        target = target.cuda(non_blocking=True)
        original_target = target
        target = target.max(dim=1)[0]
        # compute output
        # requires_grad=True: gradients are computed; requires_grad=False: they are not
        # with torch.no_grad(): no gradients are computed and no backprop is done (torch.no_grad() replaces volatile from older PyTorch)
        with torch.no_grad():
            output = model(input2)  # forward pass
            loss = criterion(output, target.float())  # compute the loss

        # collect outputs/targets for the top-3 metrics
        output_top3 = output
        target_top3 = target

        if i == 0:
            labels_test = target_top3  # ground-truth labels
            outputs_test = output_top3  # model predictions
        else:
            labels_test = torch.cat((labels_test, target_top3), 0)
            outputs_test = torch.cat((outputs_test, output_top3), 0)

        # measure accuracy and record loss
        pred = output.data.gt(0.0).long()  # binarize predictions (logits > 0)

        tp += (pred + target).eq(2).sum(dim=0)
        fp += (pred - target).eq(1).sum(dim=0)
        fn += (pred - target).eq(-1).sum(dim=0)
        tn += (pred + target).eq(0).sum(dim=0)
        three_pred = pred.unsqueeze(1).expand(-1, 3, -1)  # n, 3, 80
        tp_size += (three_pred + original_target).eq(2).sum(dim=0)
        fn_size += (three_pred - original_target).eq(-1).sum(dim=0)
        count += input2.size(0)

        this_tp = (pred + target).eq(2).sum()
        this_fp = (pred - target).eq(1).sum()
        this_fn = (pred - target).eq(-1).sum()
        this_tn = (pred + target).eq(0).sum()
        this_acc = (this_tp + this_tn).float() / (this_tp + this_tn + this_fp +
                                                  this_fn).float()

        this_prec = this_tp.float() / (this_tp + this_fp).float(
        ) * 100.0 if this_tp + this_fp != 0 else 0.0
        this_rec = this_tp.float() / (this_tp + this_fn).float(
        ) * 100.0 if this_tp + this_fn != 0 else 0.0

        losses.update(float(loss), input2.size(0))
        prec.update(float(this_prec), input2.size(0))
        rec.update(float(this_rec), input2.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)  # elapsed time
        end = time.time()

        p_c = [
            float(tp[i].float() / (tp[i] + fp[i]).float()) *
            100.0 if tp[i] > 0 else 0.0 for i in range(len(tp))
        ]
        r_c = [
            float(tp[i].float() / (tp[i] + fn[i]).float()) *
            100.0 if tp[i] > 0 else 0.0 for i in range(len(tp))
        ]
        f_c = [
            2 * p_c[i] * r_c[i] / (p_c[i] + r_c[i]) if tp[i] > 0 else 0.0
            for i in range(len(tp))
        ]

        mean_p_c = sum(p_c) / len(p_c)
        mean_r_c = sum(r_c) / len(r_c)
        mean_f_c = sum(f_c) / len(f_c)

        p_o = tp.sum().float() / (tp + fp).sum().float() * 100.0
        r_o = tp.sum().float() / (tp + fn).sum().float() * 100.0
        f_o = 2 * p_o * r_o / (p_o + r_o)

        if (i + 1) % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Precision {prec.val:.2f} ({prec.avg:.2f})\t'
                  'Recall {rec.val:.2f} ({rec.avg:.2f})'.format(
                      i + 1,
                      len(val_loader),
                      batch_time=batch_time,
                      loss=losses,
                      prec=prec,
                      rec=rec))
            print(
                'P_C {:.2f} R_C {:.2f} F_C {:.2f} \t P_O {:.2f} R_O {:.2f} F_O {:.2f}'
                .format(mean_p_c, mean_r_c, mean_f_c, p_o, r_o, f_o))
            print()

    print(
        '--------------------------------------------------------------------')
    print("验证集的最终结果为:")
    print(
        ' * P_C {:.2f} R_C {:.2f} F_C {:.2f} \t P_O {:.2f} R_O {:.2f} F_O {:.2f}'
        .format(mean_p_c, mean_r_c, mean_f_c, p_o, r_o, f_o))
    print(
        "-------------------------------top-3----------------------------------"
    )
    # compute top-3 metrics
    print("Computing top-3 metrics on the validation set....")
    mAP = calculate_mAP(labels_test, outputs_test)
    pc_top3, rc_top3, f1c_top3, po_top3, ro_top3, f1o_top3 = overall_topk(
        labels_test, outputs_test, 3)
    print(
        ' * P_C {:.2f} R_C {:.2f} F_C {:.2f} \t P_O {:.2f} R_O {:.2f} F_O {:.2f} \t mAP {:.2f}'
        .format(pc_top3, rc_top3, f1c_top3, po_top3, ro_top3, f1o_top3, mAP))
    print()
    test_CF, test_OF = mean_f_c, f_o  # return values
    return test_CF, test_OF
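
Note: calculate_mAP is an external helper. One hedged sketch computes per-class average precision with scikit-learn and averages over classes that have positives (the original's scaling and handling of empty classes may differ):

import numpy as np
import torch
from sklearn.metrics import average_precision_score

def calculate_mAP(labels, outputs):
    # labels: (N, C) multi-hot targets; outputs: (N, C) raw logits
    y_true = labels.detach().cpu().numpy()
    y_score = torch.sigmoid(outputs).detach().cpu().numpy()
    aps = [average_precision_score(y_true[:, c], y_score[:, c])
           for c in range(y_true.shape[1]) if y_true[:, c].sum() > 0]
    return 100.0 * float(np.mean(aps))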
Example #7
def train_multi(train_loader, model, criterion, optimizer, epoch,
                iteration_size):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()  # training loss
    prec = AverageMeter()  # precision
    rec = AverageMeter()  # recall

    # switch to train mode
    # prepare the model for training
    model.train()  # enable BatchNorm updates and Dropout
    optimizer.zero_grad()  # clear gradients
    end = time.time()
    tp, fp, fn, tn, count = 0, 0, 0, 0, 0
    for i, (input2, target) in enumerate(train_loader):
        data_time.update(time.time() - end)  # data loading time
        target = target.cuda(non_blocking=True)
        target = target.max(dim=1)[0]  # target: torch.Size([16, 80])
        # compute output
        output = model(input2)  # compute model output
        # BCE loss
        loss = criterion(output, target.float()) * 80.0  # scale the loss by 80

        # collect outputs/targets for the top-3 metrics
        output_top3 = output
        target_top3 = target

        if i == 0:
            labels_test = target_top3  # ground-truth labels
            outputs_test = output_top3  # model predictions
        else:
            labels_test = torch.cat((labels_test, target_top3), 0)
            outputs_test = torch.cat((outputs_test, output_top3), 0)

        pred = output.data.gt(0.0).long()  # values > 0 become 1, values <= 0 become 0

        tp += (pred + target).eq(2).sum(dim=0)
        fp += (pred - target).eq(1).sum(dim=0)
        fn += (pred - target).eq(-1).sum(dim=0)
        tn += (pred + target).eq(0).sum(dim=0)
        count += input2.size(0)

        this_tp = (pred + target).eq(2).sum()
        this_fp = (pred - target).eq(1).sum()
        this_fn = (pred - target).eq(-1).sum()
        this_tn = (pred + target).eq(0).sum()
        this_acc = (this_tp + this_tn).float() / (this_tp + this_tn + this_fp +
                                                  this_fn).float()

        this_prec = this_tp.float() / (this_tp + this_fp).float(
        ) * 100.0 if this_tp + this_fp != 0 else 0.0
        this_rec = this_tp.float() / (this_tp + this_fn).float(
        ) * 100.0 if this_tp + this_fn != 0 else 0.0

        losses.update(float(loss), input2.size(0))
        prec.update(float(this_prec), input2.size(0))
        rec.update(float(this_rec), input2.size(0))
        # compute gradient and do SGD step
        # backpropagation
        loss.backward()

        if i % iteration_size == iteration_size - 1:  # always true when iteration_size == 1 (single GPU)
            optimizer.step()  # update parameters
            optimizer.zero_grad()  # clear gradients
        # measure elapsed time
        batch_time.update(time.time() - end)  # elapsed time
        end = time.time()

        # C-P C-R C-F1
        p_c = [
            float(tp[i].float() / (tp[i] + fp[i]).float()) *
            100.0 if tp[i] > 0 else 0.0 for i in range(len(tp))
        ]
        r_c = [
            float(tp[i].float() / (tp[i] + fn[i]).float()) *
            100.0 if tp[i] > 0 else 0.0 for i in range(len(tp))
        ]
        f_c = [
            2 * p_c[i] * r_c[i] / (p_c[i] + r_c[i]) if tp[i] > 0 else 0.0
            for i in range(len(tp))
        ]

        mean_p_c = sum(p_c) / len(p_c)
        mean_r_c = sum(r_c) / len(r_c)
        mean_f_c = sum(f_c) / len(f_c)
        # O-P O-R O-F1
        p_o = tp.sum().float() / (tp + fp).sum().float() * 100.0
        r_o = tp.sum().float() / (tp + fn).sum().float() * 100.0
        f_o = 2 * p_o * r_o / (p_o + r_o)

        if (i + 1) % args.print_freq == 0:  # print at the configured frequency
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Precision {prec.val:.2f} ({prec.avg:.2f})\t'
                  'Recall {rec.val:.2f} ({rec.avg:.2f})'.format(
                      epoch + 1,
                      i + 1,
                      len(train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      prec=prec,
                      rec=rec))
            print(
                '*  all: P_C {:.2f} R_C {:.2f} F_C {:.2f} \t P_O {:.2f} R_O {:.2f} F_O {:.2f}'
                .format(mean_p_c, mean_r_c, mean_f_c, p_o, r_o, f_o))
            print()

    print(
        "-------------------------------top-3----------------------------------"
    )
    # compute top-3 metrics
    print("Computing top-3 metrics on the training set....")
    mAP = calculate_mAP(labels_test, outputs_test)
    pc_top3, rc_top3, f1c_top3, po_top3, ro_top3, f1o_top3 = overall_topk(
        labels_test, outputs_test, 3)
    print(
        ' * top3: P_C {:.2f} R_C {:.2f} F_C {:.2f} \t P_O {:.2f} R_O {:.2f} F_O {:.2f} \t mAP {:.2f}'
        .format(pc_top3, rc_top3, f1c_top3, po_top3, ro_top3, f1o_top3, mAP))
    print()
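
Note: overall_topk(labels, outputs, k) is not defined in this excerpt. A plausible sketch (assumed semantics): mark the top-k scores of each sample as positive predictions, then report per-class (C-P/C-R/C-F1) and overall (O-P/O-R/O-F1) precision, recall, and F1:

import numpy as np

def overall_topk(labels, outputs, k=3):
    y_true = labels.detach().cpu().numpy()
    scores = outputs.detach().cpu().numpy()
    pred = np.zeros_like(y_true)
    topk = np.argsort(-scores, axis=1)[:, :k]
    for row, cols in enumerate(topk):
        pred[row, cols] = 1  # top-k classes count as predicted positive
    tp = (pred * y_true).sum(0)
    fp = (pred * (1 - y_true)).sum(0)
    fn = ((1 - pred) * y_true).sum(0)
    p_c = 100.0 * tp / np.maximum(tp + fp, 1e-10)
    r_c = 100.0 * tp / np.maximum(tp + fn, 1e-10)
    f_c = 2 * p_c * r_c / np.maximum(p_c + r_c, 1e-10)
    p_o = 100.0 * tp.sum() / max(tp.sum() + fp.sum(), 1e-10)
    r_o = 100.0 * tp.sum() / max(tp.sum() + fn.sum(), 1e-10)
    f_o = 2 * p_o * r_o / max(p_o + r_o, 1e-10)
    return p_c.mean(), r_c.mean(), f_c.mean(), p_o, r_o, f_o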
Example #8
File: train.py Project: dingmyu/psa
def train(train_loader, model, optimizer, epoch, batch_size):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    loss_meter = AverageMeter()

    model.train()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    end = time.time()
    for i, (_, input, target) in enumerate(train_loader):
        # to avoid bn problem in ppm module with bin size 1x1, sometimes n may get 1 on one gpu during the last batch, so just discard
        # if input.shape[0] < batch_size:
        #     continue
        data_time.update(time.time() - end)
        current_iter = (epoch - 1) * len(train_loader) + i + 1
        max_iter = args.epochs * len(train_loader)
        poly_learning_rate(optimizer,
                           args.base_lr,
                           current_iter,
                           max_iter,
                           power=args.power,
                           index_split=4)

        input = input.cuda()
        input_var = torch.autograd.Variable(input)
        output1, output3, output6, output9 = model(input_var)
        output1 = output1.squeeze(3).squeeze(2)
        output3 = output3.squeeze(3).squeeze(2)
        output6 = output6.squeeze(3).squeeze(2)
        output9 = output9.squeeze(3).squeeze(2)

        target = target.cuda(non_blocking=True)
        target_var = torch.autograd.Variable(target)

        main_loss = (
            F.multilabel_soft_margin_loss(output1, target_var) +
            F.multilabel_soft_margin_loss(output3, target_var) +
            F.multilabel_soft_margin_loss(output6, target_var) +
            F.multilabel_soft_margin_loss(output9, target_var)) / world_size

        loss = main_loss
        optimizer.zero_grad()
        loss.backward()
        average_gradients(model)
        optimizer.step()

        reduced_loss = loss.data.clone()
        reduced_main_loss = main_loss.data.clone()
        dist.all_reduce(reduced_loss)
        dist.all_reduce(reduced_main_loss)

        main_loss_meter.update(reduced_main_loss[0], input.size(0))
        loss_meter.update(reduced_loss[0], input.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        # calculate remain time
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        if rank == 0:
            if (i + 1) % args.print_freq == 0:
                logger.info(
                    'Epoch: [{}/{}][{}/{}] '
                    'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                    'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    'Remain {remain_time} '
                    'MainLoss {main_loss_meter.val:.4f} '
                    'Loss {loss_meter.val:.4f} '.format(
                        epoch,
                        args.epochs,
                        i + 1,
                        len(train_loader),
                        batch_time=batch_time,
                        data_time=data_time,
                        remain_time=remain_time,
                        main_loss_meter=main_loss_meter,
                        loss_meter=loss_meter))

    if rank == 0:
        logger.info('Train result at epoch [{}/{}]'.format(epoch, args.epochs))
    return main_loss_meter.avg
Example #9
def train(train_loader, train_transform, model, model_icr, model_pfr,
          model_prp, TripletLoss, criterion, optimizer, epoch, zoom_factor,
          batch_size, aux_weight):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    r_loss_meter = AverageMeter()
    t_loss_meter = AverageMeter()
    proj_loss_meter = AverageMeter()
    triple_loss_meter = AverageMeter()
    rfine_loss_meter = AverageMeter()
    tfine_loss_meter = AverageMeter()
    loss_meter = AverageMeter()

    model.train()
    model_icr.train()
    model_pfr.train()
    model_prp.train()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    #print(rank)
    end = time.time()
    for index, (image_anchor, image1, image2, image3, relative_t1, relative_r1,
                relative_t2, relative_r2, relative_t3, relative_r3, image1_r,
                image2_r, image3_r, anchor_name, absolute_r1, absolute_t1,
                absolute_r2, absolute_t2, absolute_r3, absolute_t3,
                absolute_ranchor, absolute_tanchor) in enumerate(train_loader):
        # to avoid bn problem in ppm module with bin size 1x1, sometimes n may get 1 on one gpu during the last batch, so just discard
        # if input.shape[0] < batch_size:
        #     continue
        data_time.update(time.time() - end)
        current_iter = (epoch - 1) * len(train_loader) + index + 1
        max_iter = args.epochs * len(train_loader)
        index_split = 4
        poly_learning_rate(optimizer,
                           args.base_lr,
                           current_iter,
                           max_iter,
                           power=args.power,
                           index_split=index_split)

        #         print(image_anchor.size())
        image_anchor = torch.cat([image_anchor, image_anchor, image_anchor], 0)
        image1 = torch.cat([image1, image2, image3], 0)
        relative_t1 = torch.cat([relative_t1, relative_t2, relative_t3], 0)
        relative_r1 = torch.cat([relative_r1, relative_r2, relative_r3], 0)
        image1_r = torch.cat([image1_r, image2_r, image3_r], 0)
        #         print(image_anchor.size())

        image_anchor = image_anchor.cuda()
        image_anchor_var = torch.autograd.Variable(image_anchor)
        image1 = image1.cuda()
        image1_var = torch.autograd.Variable(image1)
        x1_ICR, x1_PFR, x1_PRP = model(image_anchor_var)
        x2_ICR, x2_PFR, x2_PRP = model(image1_var)

        proj_ICR = model_icr(torch.cat([x1_ICR, x2_ICR], 1))
        trans, quat = model_pfr(torch.cat([x1_PFR, x2_PFR], 1))

        translation = relative_t1.float().cuda(non_blocking=True)
        translation_var = torch.autograd.Variable(translation)
        quaternions = relative_r1.float().cuda(non_blocking=True)
        quaternions_var = torch.autograd.Variable(quaternions)
        proj = image1_r.float().cuda(non_blocking=True)
        proj_var = torch.autograd.Variable(proj)

        triple_loss = TripletLoss(
            x1_ICR.squeeze(3).squeeze(2),
            x2_ICR.squeeze(3).squeeze(2)) / world_size
        t_loss = criterion(trans.squeeze(3).squeeze(2),
                           translation_var) / world_size
        r_loss = criterion(quat.squeeze(3).squeeze(2),
                           quaternions_var) / world_size
        proj_loss = criterion(proj_ICR.squeeze(3).squeeze(2),
                              proj_var) / world_size

        #########################################################################################

        absolute_r1 = torch.cat([absolute_r1, absolute_r2, absolute_r3], 0)
        absolute_t1 = torch.cat([absolute_t1, absolute_t2, absolute_t3], 0)

        course_t = absolute_t1 - trans.squeeze(3).squeeze(2).data.cpu()
        course_r = get_coarse_quaternion(
            absolute_r1,
            quat.squeeze(3).squeeze(2).data.cpu() *
            np.array([1, -1, -1, -1], np.float32))
        course_rt = torch.cat([course_r, course_t], 1).numpy()
        #print(anchor_name, absolute_r1.size(), quat.size(), course_t, course_r)

        name_list = []
        rt_list = []
        for item in anchor_name:
            name_info, rt_info = get_retrival_info(
                item.replace(
                    args.data_root,
                    '/mnt/lustre/dingmingyu/Research/ICCV19/CamNet/scripts/retrival_lists/'
                ))
            name_list.append(name_info)
            rt_list.append(rt_info)

        #print(len(name_list), name_list)

        fine_list = []
        fine_rt_list = []
        r_fine_loss = 0
        t_fine_loss = 0
        for i in range(course_rt.shape[0]):
            distances = pose_distance(course_rt[i:i + 1],
                                      rt_list[i % int(course_rt.shape[0] / 3)])
            num = np.argmin(distances)
            #print(num, distances[num])
            fine_list.append(name_list[i % int(course_rt.shape[0] / 3)][num])
            fine_rt_list.append(rt_list[i % int(course_rt.shape[0] / 3)][num])

        for i in range(len(anchor_name)):
            #print(anchor_name[i], args.data_root + fine_list[i])
            fine_anchor, fine_1, fine_2, fine_3 = cv2.imread(
                anchor_name[i].replace('pose.txt', 'color.png')), cv2.imread(
                    args.data_root +
                    fine_list[i].replace('pose.txt', 'color.png')), cv2.imread(
                        args.data_root + fine_list[i + len(anchor_name)].
                        replace('pose.txt', 'color.png')), cv2.imread(
                            args.data_root +
                            fine_list[i + len(anchor_name) + len(anchor_name)].
                            replace('pose.txt', 'color.png'))
            fine_anchor, fine_1, fine_2, fine_3 = train_transform(
                fine_anchor, fine_1, fine_2, fine_3)
            fine_anchor = torch.cat([
                fine_anchor.unsqueeze(0),
                fine_anchor.unsqueeze(0),
                fine_anchor.unsqueeze(0)
            ], 0)
            fine_1 = torch.cat([
                fine_1.unsqueeze(0),
                fine_2.unsqueeze(0),
                fine_3.unsqueeze(0)
            ], 0)

            fine_anchor_r = absolute_ranchor[i:i + 1]
            fine_anchor_r = torch.cat(
                [fine_anchor_r, fine_anchor_r, fine_anchor_r], 0)
            fine_anchor_t = absolute_tanchor[i:i + 1]
            fine_anchor_t = torch.cat(
                [fine_anchor_t, fine_anchor_t, fine_anchor_t], 0)
            fine_imgs_rt = np.array([
                fine_rt_list[i], fine_rt_list[i + len(anchor_name)],
                fine_rt_list[i + len(anchor_name) + len(anchor_name)]
            ]).astype(np.float32)
            fine_imgs_r = torch.from_numpy(fine_imgs_rt[:, :4])
            fine_imgs_t = torch.from_numpy(fine_imgs_rt[:, 4:])
            #print(fine_anchor_r.size(), fine_anchor_t.size(), fine_imgs_r.size(), fine_imgs_t.size())
            fine_rela_t = fine_imgs_t - fine_anchor_t
            fine_rela_t_var = torch.autograd.Variable(fine_rela_t.cuda())
            fine_anchor_r[:, 1:] *= -1
            fine_rela_r = get_coarse_quaternion(fine_anchor_r, fine_imgs_r)
            fine_rela_r_var = torch.autograd.Variable(fine_rela_r.cuda())
            #print(fine_rela_r.size(), fine_rela_t.size(), fine_rela_r, fine_rela_t)
            fine_anchor_var = torch.autograd.Variable(fine_anchor.cuda())
            fine_imgs_var = torch.autograd.Variable(fine_1.cuda())
            _, _, anchor_PRP = model(fine_anchor_var)
            _, _, imgs_PRP = model(fine_imgs_var)
            trans_PRP, quat_PRP = model_prp(
                torch.cat([anchor_PRP, imgs_PRP], 1))
            r_fine_loss += criterion(
                quat_PRP.squeeze(3).squeeze(2),
                fine_rela_r_var) / world_size / len(anchor_name)
            t_fine_loss += criterion(
                trans_PRP.squeeze(3).squeeze(2),
                fine_rela_t_var) / world_size / len(anchor_name)
            if rank == 0:
                print(anchor_name[i], args.data_root + fine_list[i],
                      fine_list[i + len(anchor_name)],
                      fine_list[i + len(anchor_name) + len(anchor_name)],
                      fine_anchor_r, fine_anchor_t, fine_rela_t, fine_rela_r)
            #print(anchor_name[i], args.data_root + fine_list[i], fine_anchor_r[i], fine_anchor_t[i], fine_rt_list[i], fine_rt_list[i+len(anchor_name)], fine_rt_list[i+len(anchor_name)+len(anchor_name)])

        loss = r_loss + t_loss + proj_loss + triple_loss + r_fine_loss + t_fine_loss

        optimizer.zero_grad()
        loss.backward()
        average_gradients(model)
        optimizer.step()

        reduced_loss = loss.data.clone()
        reduced_t_loss = t_loss.data.clone()
        reduced_r_loss = r_loss.data.clone()
        reduced_proj_loss = proj_loss.data.clone()
        reduced_triple_loss = triple_loss.data.clone()
        reduced_rfine_loss = r_fine_loss.data.clone()
        reduced_tfine_loss = t_fine_loss.data.clone()
        dist.all_reduce(reduced_loss)
        dist.all_reduce(reduced_t_loss)
        dist.all_reduce(reduced_r_loss)
        dist.all_reduce(reduced_proj_loss)
        dist.all_reduce(reduced_triple_loss)
        dist.all_reduce(reduced_rfine_loss)
        dist.all_reduce(reduced_tfine_loss)

        r_loss_meter.update(reduced_r_loss[0], image_anchor.size(0))
        t_loss_meter.update(reduced_t_loss[0], image_anchor.size(0))
        proj_loss_meter.update(reduced_proj_loss[0], image_anchor.size(0))
        triple_loss_meter.update(reduced_triple_loss[0], image_anchor.size(0))
        rfine_loss_meter.update(reduced_rfine_loss[0], image_anchor.size(0))
        tfine_loss_meter.update(reduced_tfine_loss[0], image_anchor.size(0))
        loss_meter.update(reduced_loss[0], image_anchor.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        # calculate remain time
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        if rank == 0:
            if (index + 1) % args.print_freq == 0:
                logger.info(
                    'Epoch: [{}/{}][{}/{}] '
                    'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                    'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    'Remain {remain_time} '
                    'rLoss {r_loss_meter.val:.4f} '
                    'tLoss {t_loss_meter.val:.4f} '
                    'projLoss {proj_loss_meter.val:.4f} '
                    'tripleLoss {triple_loss_meter.val:.4f} '
                    'rfineLoss {rfine_loss_meter.val:.4f} '
                    'tfineLoss {tfine_loss_meter.val:.4f} '
                    'Loss {loss_meter.val:.4f} '.format(
                        epoch,
                        args.epochs,
                        index + 1,
                        len(train_loader),
                        batch_time=batch_time,
                        data_time=data_time,
                        remain_time=remain_time,
                        t_loss_meter=t_loss_meter,
                        r_loss_meter=r_loss_meter,
                        proj_loss_meter=proj_loss_meter,
                        triple_loss_meter=triple_loss_meter,
                        rfine_loss_meter=rfine_loss_meter,
                        tfine_loss_meter=tfine_loss_meter,
                        loss_meter=loss_meter))
            writer.add_scalar('loss_train_batch_r', r_loss_meter.val,
                              current_iter)
            writer.add_scalar('loss_train_batch_t', t_loss_meter.val,
                              current_iter)

    if rank == 0:
        logger.info('Train result at epoch [{}/{}].'.format(
            epoch, args.epochs))
    return t_loss_meter.avg, r_loss_meter.avg
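
Note: get_coarse_quaternion is not defined in this excerpt. It appears to compose two batches of (w, x, y, z) quaternions; a hedged sketch using the batched Hamilton product:

import torch

def get_coarse_quaternion(q1, q2):
    # q1, q2: (N, 4) quaternions in (w, x, y, z) order; returns q1 * q2
    w1, x1, y1, z1 = q1.unbind(-1)
    w2, x2, y2, z2 = q2.unbind(-1)
    w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
    x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
    y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
    z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
    return torch.stack([w, x, y, z], dim=-1)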
Example #10
def train(train_loader, model, criterion, optimizer, epoch, zoom_factor, batch_size, ins_weight):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    # aux_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.train()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    end = time.time()

    for i, (pre_image_patch, pre_aug_mask, pre_ins_mask, flow_patch, inverse_flow_patch,
            pre_image_patch_2, pre_ins_mask_2, flow_patch_2, inverse_flow_patch_2,
            image_patch, ins_mask) in enumerate(train_loader):
        
        # discard these two inputs (not used below)
        pre_image_patch = None
        pre_ins_mask = None
        
        data_time.update(time.time() - end)
        current_iter = (epoch - 1) * len(train_loader) + i + 1
        max_iter = args.epochs * len(train_loader)
        if args.net_type == 0:
            index_split = 4
        poly_learning_rate(optimizer, args.base_lr, current_iter, max_iter, power=args.power, index_split=index_split)
        # pre image1.
#         print (np.squeeze(pre_aug_mask.cpu().numpy()).shape) # (1, 1, 433, 433)
        pre_tmp_out = np.squeeze(pre_aug_mask.cpu().numpy(), axis=0).transpose((1, 2, 0))
        pre_tmp_out = np.squeeze(label_to_prob(pre_tmp_out, 1))
        
        flow_patch = np.squeeze(flow_patch.cpu().numpy()).transpose((1, 2, 0))
        inverse_flow_patch_numpy = np.squeeze(inverse_flow_patch.cpu().numpy()).transpose((1, 2, 0))
        warp_pred = prob_to_label(flo.get_warp_label(flow_patch, inverse_flow_patch_numpy, pre_tmp_out))
        pre = torch.from_numpy(warp_pred).contiguous().float()
        warped_pred_aug_mask_var = torch.autograd.Variable(torch.unsqueeze(torch.unsqueeze(pre, dim=0), dim=0).cuda(non_blocking=True))

        inverse_flow_patch_var = torch.autograd.Variable(inverse_flow_patch.cuda(non_blocking=True))

        pre_input_var = torch.autograd.Variable(pre_image_patch_2.cuda(non_blocking=True))
        # input model
        pre_output_ins = model(pre_input_var, warped_pred_aug_mask_var, inverse_flow_patch_var)

        seg_loss = 0
        pre_ins_mask222 = torch.autograd.Variable(pre_ins_mask_2).squeeze(1).long()
        pre_ins_mask_var = torch.autograd.Variable(pre_ins_mask222.cuda(non_blocking=True))
#         #debug1
#         import cv2
# #         pre_image_patch_2_debug = np.squeeze(pre_image_patch_2.cpu().numpy()).transpose((1, 2, 0))
#         warped_pred_aug_mask_var_debug = torch.unsqueeze(pre, dim=0).cpu().numpy().transpose((1, 2, 0))
#         pre_ins_mask_var_debug = np.squeeze(pre_ins_mask_2.cpu().numpy())
        
#         warped_pred_aug_mask_var_debug[warped_pred_aug_mask_var_debug == 255] = 0
#         pre_ins_mask_var_debug[pre_ins_mask_var_debug == 255] = 0
# #         cv2.imwrite("debug/"+str(rank)+ "__" + str(i) +"pre_patch.jpg", pre_image_patch_2_debug)
#         cv2.imwrite("debug/"+str(rank)+ "__" + str(i) +"pre_warped_patch.png", warped_pred_aug_mask_var_debug * 255)
#         cv2.imwrite("debug/"+str(rank)+ "__" + str(i) +"pre_ins_mask.png", pre_ins_mask_var_debug * 255)
#         #debug1 over
        ins_loss = criterion(pre_output_ins, pre_ins_mask_var) / world_size
        last_ins_loss = ins_loss
        loss = seg_loss + ins_loss
        loss = loss * 0.8  # two-loss weight
        
        # current image.
        image_patch_var = torch.autograd.Variable(image_patch.cuda(non_blocking=True))
        # pre_output_ins  (1, 2 ,433, 433)
#         pre_output_ins_var = torch.argmax(pre_output_ins, dim=1, keepdim=True).float()
        pre_output_ins_var = torch.argmax(pre_output_ins[0], dim=0, keepdim=True).float()
#         tmp_out = pre_output_ins_var.data.max(1)[1].cpu().numpy().transpose((1,2,0))
        tmp_out = pre_output_ins_var.cpu().numpy().transpose((1,2,0))
#         print (np.unique(tmp_out))
#         cv2.imwrite("debug/"+str(rank)+ "__" + str(i) +"pre_out.png", tmp_out * 255)
        tmp_prob = np.squeeze(label_to_prob(tmp_out, 1))
        
        flow_patch_2 = np.squeeze(flow_patch_2.cpu().numpy()).transpose((1, 2, 0))
        inverse_flow_patch_2_numpy = np.squeeze(inverse_flow_patch_2.cpu().numpy()).transpose((1, 2, 0))
        #warp
        warp_pred = prob_to_label(flo.get_warp_label(flow_patch_2, inverse_flow_patch_2_numpy, tmp_prob))
        pre = torch.from_numpy(warp_pred).contiguous().float()
        warped_pred_aug_mask_var = torch.autograd.Variable(torch.unsqueeze(torch.unsqueeze(pre, dim=0), dim=0).cuda(non_blocking=True))
        # pred_aug_mask_var warped to this image
        inverse_flow_patch_var_2 = torch.autograd.Variable(inverse_flow_patch_2.cuda(non_blocking=True))
        
        # input model
        output_ins = model(image_patch_var, warped_pred_aug_mask_var, inverse_flow_patch_var_2)

        ins_mask = torch.autograd.Variable(ins_mask).squeeze(1).long()
        ins_mask_var = torch.autograd.Variable(ins_mask.cuda(non_blocking=True))
        #debug2
#         pre_image_patch_2_debug = np.squeeze(image_patch.cpu().numpy()).transpose((1, 2, 0))
#         warped_pred_aug_mask_var_debug = torch.unsqueeze(pre, dim=0).cpu().numpy().transpose((1, 2, 0))
#         pre_ins_mask_var_debug = np.squeeze(ins_mask.cpu().numpy())
#         warped_pred_aug_mask_var_debug[warped_pred_aug_mask_var_debug == 255] = 0
#         pre_ins_mask_var_debug[pre_ins_mask_var_debug == 255] = 0
# #         cv2.imwrite("debug/"+str(rank)+ "__" + str(i) +"now_patch.jpg", pre_image_patch_2_debug)
#         cv2.imwrite("debug/"+str(rank)+ "__" + str(i) +"now_warped_patch.png", warped_pred_aug_mask_var_debug * 255)
#         cv2.imwrite("debug/"+str(rank)+ "__" + str(i) +"now_ins_mask.png", pre_ins_mask_var_debug * 255)
#         #debug2 over
        seg_loss = 0
        ins_loss = criterion(output_ins, ins_mask_var) / world_size
        loss = ins_loss + seg_loss

        optimizer.zero_grad()
        loss.backward()
        average_gradients(model)
        optimizer.step()

        output = output_ins.data.max(1)[1].cpu().numpy()
        target = ins_mask.cpu().numpy()
        intersection, union, target = intersectionAndUnion(output, target, 2, args.ignore_label)  # 1 = args.classes
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)

        reduced_loss = last_ins_loss.data.clone()
        reduced_main_loss = ins_loss.data.clone()
        # reduced_aux_loss = ins_loss.data.clone()  # ins_loss replace here.
        dist.all_reduce(reduced_loss)
        dist.all_reduce(reduced_main_loss)
        # dist.all_reduce(reduced_aux_loss)

        main_loss_meter.update(reduced_main_loss[0], image_patch.size(0))
        # aux_loss_meter.update(reduced_aux_loss[0], input.size(0))
        loss_meter.update(reduced_loss[0], image_patch.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        # calculate remain time
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        if rank == 0:
            if (i + 1) % args.print_freq == 0:
                logger.info('Epoch: [{}/{}][{}/{}] '
                            'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                            'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                            'Remain {remain_time} '
                            'cur ins Loss {main_loss_meter.val:.4f} '
                            'pre ins Loss {loss_meter.val:.4f} '
                            'Accuracy {accuracy:.4f}.'.format(epoch, args.epochs, i + 1, len(train_loader),
                                                              batch_time=batch_time,
                                                              data_time=data_time,
                                                              remain_time=remain_time,
                                                              main_loss_meter=main_loss_meter,
                                                              loss_meter=loss_meter,
                                                              accuracy=accuracy))
            writer.add_scalar('loss_train_batch', main_loss_meter.val, current_iter)
            writer.add_scalar('mIoU_train_batch', np.mean(intersection / (union + 1e-10)), current_iter)
            writer.add_scalar('mAcc_train_batch', np.mean(intersection / (target + 1e-10)), current_iter)
            writer.add_scalar('allAcc_train_batch', accuracy, current_iter)

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    if rank == 0:
        logger.info(
            'Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(epoch, args.epochs, mIoU,
                                                                                           mAcc, allAcc))
    return main_loss_meter.avg, mIoU, mAcc, allAcc
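
Note: label_to_prob and prob_to_label are external conversion helpers. A minimal sketch, assuming they map an H x W x 1 binary label patch to a one-hot probability map and back (the original may handle the 255 ignore value differently):

import numpy as np

def label_to_prob(label, num_objects):
    # label: (H, W, 1) integer map; returns (H, W, num_objects + 1) one-hot probabilities
    label = np.squeeze(label, axis=-1).astype(np.int64)
    label = np.where(label > num_objects, 0, label)  # assumption: fold ignore values into background
    return np.eye(num_objects + 1, dtype=np.float32)[label]

def prob_to_label(prob):
    # prob: (H, W, C) probabilities; returns (H, W) hard labels
    return np.argmax(prob, axis=-1).astype(np.float32)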