Beispiel #1
0
def train(train_loader, model, model_ppm, criterion, optimizer, epoch, zoom_factor, batch_size, aux_weight):
    """Run one epoch of distributed relative-pose training.

    Both images of a pair go through the shared backbone ``model``. Its three
    returned feature tensors are averaged per image, the two averages are
    concatenated along dim 1, and ``model_ppm`` regresses a translation and a
    quaternion from them. ``criterion`` scores both outputs.

    Args:
        train_loader: yields ``(input1, input2, translation, quaternions)``.
        model: backbone returning three feature tensors per image.
        model_ppm: head returning ``(trans, quat)``.
        criterion: loss used for both translation and rotation.
        optimizer: optimizer whose learning rate ``poly_learning_rate`` updates.
        epoch: 1-based epoch index (used to compute the global iteration).
        zoom_factor, batch_size, aux_weight: unused; kept for call compatibility.

    Returns:
        ``(t_loss_meter.avg, r_loss_meter.avg)``: epoch averages of the
        all-reduced translation and rotation losses.

    Note:
        Requires PyTorch >= 0.4 (``non_blocking`` / ``Tensor.item``). The old
        ``cuda(async=True)`` spelling is a SyntaxError on Python 3.7+.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    r_loss_meter = AverageMeter()
    t_loss_meter = AverageMeter()
    loss_meter = AverageMeter()

    model.train()
    model_ppm.train()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    max_iter = args.epochs * len(train_loader)
    end = time.time()
    for i, (input1, input2, translation, quaternions) in enumerate(train_loader):
        data_time.update(time.time() - end)
        current_iter = (epoch - 1) * len(train_loader) + i + 1
        poly_learning_rate(optimizer, args.base_lr, current_iter, max_iter, power=args.power, index_split=4)

        input1 = input1.cuda()
        input2 = input2.cuda()
        x1_icr, x1_pfr, x1_prp = model(input1)
        x2_icr, x2_pfr, x2_prp = model(input2)
        # Fuse the three backbone branches of each image by plain averaging.
        feat1 = (x1_icr + x1_pfr + x1_prp) / 3
        feat2 = (x2_icr + x2_pfr + x2_prp) / 3
        trans, quat = model_ppm(torch.cat([feat1, feat2], 1))

        translation = translation.float().cuda(non_blocking=True)
        quaternions = quaternions.float().cuda(non_blocking=True)

        # Pre-divide by world_size so the SUM all-reduces below yield cross-rank means.
        t_loss = criterion(trans, translation) / world_size
        r_loss = criterion(quat, quaternions) / world_size
        loss = r_loss + t_loss

        optimizer.zero_grad()
        loss.backward()
        average_gradients(model)
        # model_ppm also contributes to the loss. Without synchronising its
        # gradients, each rank would update its own copy of the head from local
        # data only. (If its parameters are not in ``optimizer``, this is merely
        # extra communication.)
        average_gradients(model_ppm)
        optimizer.step()

        reduced_loss = loss.detach().clone()
        reduced_t_loss = t_loss.detach().clone()
        reduced_r_loss = r_loss.detach().clone()
        dist.all_reduce(reduced_loss)
        dist.all_reduce(reduced_t_loss)
        dist.all_reduce(reduced_r_loss)

        # .item() instead of ``[0]``: indexing a 0-dim loss tensor raises in PyTorch >= 0.5.
        n = input1.size(0)
        r_loss_meter.update(reduced_r_loss.item(), n)
        t_loss_meter.update(reduced_t_loss.item(), n)
        loss_meter.update(reduced_loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        # Estimate the remaining time from the running average batch time.
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        if rank == 0:
            if (i + 1) % args.print_freq == 0:
                logger.info('Epoch: [{}/{}][{}/{}] '
                            'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                            'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                            'Remain {remain_time} '
                            'rLoss {r_loss_meter.val:.4f} '
                            'tLoss {t_loss_meter.val:.4f} '
                            'Loss {loss_meter.val:.4f} '.format(epoch, args.epochs, i + 1, len(train_loader),
                                                              batch_time=batch_time,
                                                              data_time=data_time,
                                                              remain_time=remain_time,
                                                              t_loss_meter=t_loss_meter,
                                                              r_loss_meter=r_loss_meter,
                                                              loss_meter=loss_meter))
            writer.add_scalar('loss_train_batch_r', r_loss_meter.val, current_iter)
            writer.add_scalar('loss_train_batch_t', t_loss_meter.val, current_iter)

    if rank == 0:
        logger.info('Train result at epoch [{}/{}].'.format(epoch, args.epochs))
    return t_loss_meter.avg, r_loss_meter.avg
def train(train_loader, model, criterion, optimizer, epoch, zoom_factor, batch_size, aux_weight, fcw):
    """Run one epoch of distributed segmentation training with an auxiliary loss.

    ``fcw`` is applied to each input batch, and its output is passed to
    ``model`` together with the images. ``model`` returns the main and
    auxiliary predictions, each scored by ``criterion`` against the
    (optionally resized) target.

    Args:
        train_loader: yields ``(input, target)`` batches.
        model: called as ``model(images, fcw(images))`` -> ``(output, aux)``.
        criterion: loss for both the main and auxiliary outputs.
        optimizer: optimizer whose learning rate ``poly_learning_rate`` updates.
        epoch: 1-based epoch index.
        zoom_factor: when not 8, the target is resized to
            ``int((size - 1) / 8 * zoom_factor + 1)`` along dims 1 and 2.
        batch_size: unused; kept for call compatibility.
        aux_weight: weight of the auxiliary loss in the total loss.
        fcw: module whose output is fed to ``model`` as its second input.

    Returns:
        ``(main_loss_meter.avg, mIoU, mAcc, allAcc)``. The loss average is
        all-reduced across ranks. The IoU/accuracy statistics are computed
        from this rank's batches only (they are not all-reduced).

    Raises:
        ValueError: if ``args.net_type`` is not 0, 1, 2 or 3. Previously this
            surfaced as an UnboundLocalError on the first batch.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    aux_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.train()
    fcw.train()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    # The parameter-group split for poly_learning_rate depends only on the network type.
    if args.net_type == 0:
        index_split = 4
    elif args.net_type in [1, 2, 3]:
        index_split = 5
    else:
        raise ValueError('unsupported args.net_type: {}'.format(args.net_type))
    max_iter = args.epochs * len(train_loader)
    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        current_iter = (epoch - 1) * len(train_loader) + i + 1
        poly_learning_rate(optimizer, args.base_lr, current_iter, max_iter, power=args.power, index_split=index_split)

        images = images.cuda()
        fcw_input = fcw(images)
        output, aux = model(images, fcw_input)

        if zoom_factor != 8:
            h = int((target.size()[1] - 1) / 8 * zoom_factor + 1)
            w = int((target.size()[2] - 1) / 8 * zoom_factor + 1)
            # Labels are resized bilinearly (as before) and truncated back to ints
            # with .long(). align_corners=True matches the corner-aligned size formula
            # above and the pre-0.4 default of F.upsample; without it, PyTorch >= 0.4
            # silently uses align_corners=False.
            target = F.interpolate(target.unsqueeze(1).float(), size=(h, w),
                                   mode='bilinear', align_corners=True).squeeze(1).long()
        target = target.cuda(non_blocking=True)
        # Pre-divide by world_size so the SUM all-reduces below yield cross-rank means.
        main_loss = criterion(output, target) / world_size
        aux_loss = criterion(aux, target) / world_size
        loss = main_loss + aux_weight * aux_loss

        optimizer.zero_grad()
        loss.backward()
        average_gradients(model)
        # fcw also feeds the loss. Synchronise its gradients too so replicas stay
        # identical (harmless if fcw has no trainable parameters in the optimizer).
        average_gradients(fcw)
        optimizer.step()

        pred = output.detach().max(1)[1].cpu().numpy()
        intersection, union, target_area = intersectionAndUnion(
            pred, target.cpu().numpy(), args.classes, args.ignore_label)
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target_area)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)

        reduced_loss = loss.detach().clone()
        reduced_main_loss = main_loss.detach().clone()
        reduced_aux_loss = aux_loss.detach().clone()
        dist.all_reduce(reduced_loss)
        dist.all_reduce(reduced_main_loss)
        dist.all_reduce(reduced_aux_loss)

        # .item() instead of ``[0]``: indexing a 0-dim loss tensor raises in PyTorch >= 0.5.
        n = images.size(0)
        main_loss_meter.update(reduced_main_loss.item(), n)
        aux_loss_meter.update(reduced_aux_loss.item(), n)
        loss_meter.update(reduced_loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        # Estimate the remaining time from the running average batch time.
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        if rank == 0:
            if (i + 1) % args.print_freq == 0:
                logger.info('Epoch: [{}/{}][{}/{}] '
                            'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                            'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                            'Remain {remain_time} '
                            'MainLoss {main_loss_meter.val:.4f} '
                            'AuxLoss {aux_loss_meter.val:.4f} '
                            'Loss {loss_meter.val:.4f} '
                            'Accuracy {accuracy:.4f}.'.format(epoch, args.epochs, i + 1, len(train_loader),
                                                              batch_time=batch_time,
                                                              data_time=data_time,
                                                              remain_time=remain_time,
                                                              main_loss_meter=main_loss_meter,
                                                              aux_loss_meter=aux_loss_meter,
                                                              loss_meter=loss_meter,
                                                              accuracy=accuracy))
            writer.add_scalar('loss_train_batch', main_loss_meter.val, current_iter)
            writer.add_scalar('mIoU_train_batch', np.mean(intersection / (union + 1e-10)), current_iter)
            writer.add_scalar('mAcc_train_batch', np.mean(intersection / (target_area + 1e-10)), current_iter)
            writer.add_scalar('allAcc_train_batch', accuracy, current_iter)

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    if rank == 0:
        logger.info('Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(epoch, args.epochs, mIoU, mAcc, allAcc))
    return main_loss_meter.avg, mIoU, mAcc, allAcc
Beispiel #3
0
def train(train_loader, model, criterion, optimizer, epoch, zoom_factor,
          batch_size, aux_weight):
    """Run one epoch of distributed multi-label training.

    The model output has dims 3 and 2 squeezed away (a no-op unless they have
    size 1). It is scored against the attribute targets with
    ``F.multilabel_soft_margin_loss``.

    Args:
        train_loader: yields ``(_, input, atts)``; the first element is ignored.
        model: network applied to the input batch.
        optimizer: optimizer whose learning rate ``poly_learning_rate`` updates.
        epoch: 1-based epoch index.
        criterion, zoom_factor, batch_size, aux_weight: unused; kept for call
            compatibility.

    Returns:
        Epoch average of the all-reduced loss (``main_loss_meter.avg``).

    Note:
        Requires PyTorch >= 0.4 (``non_blocking`` / ``Tensor.item``). The old
        ``cuda(async=True)`` spelling is a SyntaxError on Python 3.7+.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    loss_meter = AverageMeter()

    model.train()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    max_iter = args.epochs * len(train_loader)
    end = time.time()
    for i, (_, images, atts) in enumerate(train_loader):
        data_time.update(time.time() - end)
        current_iter = (epoch - 1) * len(train_loader) + i + 1
        poly_learning_rate(optimizer,
                           args.base_lr,
                           current_iter,
                           max_iter,
                           power=args.power,
                           index_split=4)

        images = images.cuda()
        output = model(images).squeeze(3).squeeze(2)
        atts = atts.cuda(non_blocking=True)
        # Pre-divide by world_size so the SUM all-reduce below yields the cross-rank mean.
        main_loss = F.multilabel_soft_margin_loss(output, atts) / world_size

        optimizer.zero_grad()
        main_loss.backward()
        average_gradients(model)
        optimizer.step()

        # There is no auxiliary term, so the total loss equals main_loss and a
        # single all-reduce serves both meters.
        reduced_main_loss = main_loss.detach().clone()
        dist.all_reduce(reduced_main_loss)
        # .item() instead of ``[0]``: indexing a 0-dim loss tensor raises in PyTorch >= 0.5.
        reduced_value = reduced_main_loss.item()

        main_loss_meter.update(reduced_value, images.size(0))
        loss_meter.update(reduced_value, images.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        # Estimate the remaining time from the running average batch time.
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        if rank == 0:
            if (i + 1) % args.print_freq == 0:
                logger.info(
                    'Epoch: [{}/{}][{}/{}] '
                    'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                    'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    'Remain {remain_time} '
                    'MainLoss {main_loss_meter.val:.4f} '
                    'Loss {loss_meter.val:.4f} '.format(
                        epoch,
                        args.epochs,
                        i + 1,
                        len(train_loader),
                        batch_time=batch_time,
                        data_time=data_time,
                        remain_time=remain_time,
                        main_loss_meter=main_loss_meter,
                        loss_meter=loss_meter))
            writer.add_scalar('loss_train_batch', main_loss_meter.val,
                              current_iter)

    if rank == 0:
        logger.info('Train result at epoch [{}/{}]: '.format(
            epoch, args.epochs))
    return main_loss_meter.avg
Beispiel #4
0
def train(train_loader, model, optimizer, epoch, batch_size):
    """Run one epoch of distributed multi-label training over several model outputs.

    ``model`` returns a sequence of prediction tensors (four at the original
    call site). Each has dims 3 and 2 squeezed away (a no-op unless they have
    size 1) and is scored against the same target with
    ``F.multilabel_soft_margin_loss``. The per-output losses are summed.

    Args:
        train_loader: yields ``(_, input, target)``; the first element is ignored.
        model: network returning an iterable of prediction tensors.
        optimizer: optimizer whose learning rate ``poly_learning_rate`` updates.
        epoch: 1-based epoch index.
        batch_size: unused; kept for call compatibility.

    Returns:
        Epoch average of the all-reduced summed loss (``main_loss_meter.avg``).

    Note:
        Requires PyTorch >= 0.4 (``non_blocking`` / ``Tensor.item``). The old
        ``cuda(async=True)`` spelling is a SyntaxError on Python 3.7+.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    loss_meter = AverageMeter()

    model.train()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    max_iter = args.epochs * len(train_loader)
    end = time.time()
    for i, (_, images, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        current_iter = (epoch - 1) * len(train_loader) + i + 1
        poly_learning_rate(optimizer,
                           args.base_lr,
                           current_iter,
                           max_iter,
                           power=args.power,
                           index_split=4)

        images = images.cuda()
        outputs = [out.squeeze(3).squeeze(2) for out in model(images)]
        target = target.cuda(non_blocking=True)

        # Sum the per-output losses, then pre-divide by world_size so the SUM
        # all-reduce below yields the cross-rank mean.
        main_loss = sum(F.multilabel_soft_margin_loss(out, target)
                        for out in outputs) / world_size

        optimizer.zero_grad()
        main_loss.backward()
        average_gradients(model)
        optimizer.step()

        # The total loss is main_loss itself, so a single all-reduce serves both meters.
        reduced_main_loss = main_loss.detach().clone()
        dist.all_reduce(reduced_main_loss)
        # .item() instead of ``[0]``: indexing a 0-dim loss tensor raises in PyTorch >= 0.5.
        reduced_value = reduced_main_loss.item()

        main_loss_meter.update(reduced_value, images.size(0))
        loss_meter.update(reduced_value, images.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        # Estimate the remaining time from the running average batch time.
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        if rank == 0 and (i + 1) % args.print_freq == 0:
            logger.info(
                'Epoch: [{}/{}][{}/{}] '
                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                'Remain {remain_time} '
                'MainLoss {main_loss_meter.val:.4f} '
                'Loss {loss_meter.val:.4f} '.format(
                    epoch,
                    args.epochs,
                    i + 1,
                    len(train_loader),
                    batch_time=batch_time,
                    data_time=data_time,
                    remain_time=remain_time,
                    main_loss_meter=main_loss_meter,
                    loss_meter=loss_meter))

    if rank == 0:
        logger.info('Train result at epoch [{}/{}]'.format(epoch, args.epochs))
    return main_loss_meter.avg
Beispiel #5
0
def train(train_loader, train_transform, model, model_icr, model_pfr,
          model_prp, TripletLoss, criterion, optimizer, epoch, zoom_factor,
          batch_size, aux_weight,
          retrieval_root='/mnt/lustre/dingmingyu/Research/ICCV19/CamNet/scripts/retrival_lists/'):
    """Run one epoch of distributed coarse-to-fine relative-pose training.

    For every batch:

    1. The anchor image is repeated three times and paired with
       ``image1``/``image2``/``image3``. The three pairs (and their relative
       translation/rotation and ``image*_r`` targets) are stacked along dim 0.
    2. The shared backbone ``model`` encodes anchors and paired images.
       ``model_icr`` regresses the stacked ``image*_r`` target from the
       concatenated ICR features, and ``model_pfr`` regresses the relative
       translation and quaternion from the PFR features. ``TripletLoss`` is
       applied to the ICR features.
    3. A coarse pose is derived from the absolute poses ``absolute_r*`` /
       ``absolute_t*`` and the predicted relative pose. For every stacked row,
       the entry of the anchor's retrieval list with the smallest
       ``pose_distance`` is selected. Those images are read with ``cv2``,
       transformed with ``train_transform`` and re-encoded by ``model``, and
       ``model_prp`` regresses the relative pose between the anchor and each
       retrieved image.
    4. The six losses are summed, gradients of all four networks are
       synchronised, and the optimizer steps. All-reduced losses are logged
       on rank 0.

    Args:
        train_loader: yields the 22-tuple unpacked in the loop header.
        train_transform: callable applied to ``(anchor, img1, img2, img3)``
            arrays read by ``cv2.imread``; must return four tensors.
        model: backbone returning ``(ICR, PFR, PRP)`` features.
        model_icr, model_pfr, model_prp: regression heads (see above).
        TripletLoss: loss callable on anchor/query ICR features.
        criterion: regression loss used for all pose/projection terms.
        optimizer: optimizer whose learning rate ``poly_learning_rate`` updates.
        epoch: 1-based epoch index.
        zoom_factor, batch_size, aux_weight: unused; kept for call compatibility.
        retrieval_root: replaces ``args.data_root`` in each anchor path to
            locate its retrieval list. The default is the previously
            hard-coded path.

    Returns:
        ``(t_loss_meter.avg, r_loss_meter.avg)``: epoch averages of the
        all-reduced coarse translation and rotation losses.

    Note:
        Requires PyTorch >= 0.4 (``non_blocking`` / ``Tensor.item``). The old
        ``cuda(async=True)`` spelling is a SyntaxError on Python 3.7+.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    r_loss_meter = AverageMeter()
    t_loss_meter = AverageMeter()
    proj_loss_meter = AverageMeter()
    triple_loss_meter = AverageMeter()
    rfine_loss_meter = AverageMeter()
    tfine_loss_meter = AverageMeter()
    loss_meter = AverageMeter()

    model.train()
    model_icr.train()
    model_pfr.train()
    model_prp.train()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    max_iter = args.epochs * len(train_loader)
    end = time.time()
    for index, (image_anchor, image1, image2, image3, relative_t1, relative_r1,
                relative_t2, relative_r2, relative_t3, relative_r3, image1_r,
                image2_r, image3_r, anchor_name, absolute_r1, absolute_t1,
                absolute_r2, absolute_t2, absolute_r3, absolute_t3,
                absolute_ranchor, absolute_tanchor) in enumerate(train_loader):
        data_time.update(time.time() - end)
        current_iter = (epoch - 1) * len(train_loader) + index + 1
        poly_learning_rate(optimizer,
                           args.base_lr,
                           current_iter,
                           max_iter,
                           power=args.power,
                           index_split=4)

        # Stack the three (anchor, image_k) pairs along the batch dimension.
        image_anchor = torch.cat([image_anchor, image_anchor, image_anchor], 0)
        image1 = torch.cat([image1, image2, image3], 0)
        relative_t1 = torch.cat([relative_t1, relative_t2, relative_t3], 0)
        relative_r1 = torch.cat([relative_r1, relative_r2, relative_r3], 0)
        image1_r = torch.cat([image1_r, image2_r, image3_r], 0)

        image_anchor = image_anchor.cuda()
        image1 = image1.cuda()
        x1_ICR, x1_PFR, x1_PRP = model(image_anchor)
        x2_ICR, x2_PFR, x2_PRP = model(image1)

        proj_ICR = model_icr(torch.cat([x1_ICR, x2_ICR], 1))
        trans, quat = model_pfr(torch.cat([x1_PFR, x2_PFR], 1))

        translation = relative_t1.float().cuda(non_blocking=True)
        quaternions = relative_r1.float().cuda(non_blocking=True)
        proj = image1_r.float().cuda(non_blocking=True)

        # Losses are pre-divided by world_size so the SUM all-reduces yield cross-rank means.
        triple_loss = TripletLoss(
            x1_ICR.squeeze(3).squeeze(2),
            x2_ICR.squeeze(3).squeeze(2)) / world_size
        t_loss = criterion(trans.squeeze(3).squeeze(2),
                           translation) / world_size
        r_loss = criterion(quat.squeeze(3).squeeze(2),
                           quaternions) / world_size
        proj_loss = criterion(proj_ICR.squeeze(3).squeeze(2),
                              proj) / world_size

        # ---- Coarse pose estimate and retrieval-based fine stage ----
        absolute_r1 = torch.cat([absolute_r1, absolute_r2, absolute_r3], 0)
        absolute_t1 = torch.cat([absolute_t1, absolute_t2, absolute_t3], 0)

        course_t = absolute_t1 - trans.squeeze(3).squeeze(2).detach().cpu()
        # Components 1..3 of the predicted quaternion are negated before composing
        # (its conjugate, assuming a (w, x, y, z) layout -- confirm).
        course_r = get_coarse_quaternion(
            absolute_r1,
            quat.squeeze(3).squeeze(2).detach().cpu() *
            np.array([1, -1, -1, -1], np.float32))
        course_rt = torch.cat([course_r, course_t], 1).numpy()

        name_list = []
        rt_list = []
        for item in anchor_name:
            name_info, rt_info = get_retrival_info(
                item.replace(args.data_root, retrieval_root))
            name_list.append(name_info)
            rt_list.append(rt_info)

        # For every stacked row pick the retrieval candidate whose pose is closest
        # to the coarse estimate. Row k belongs to anchor k % group.
        group = int(course_rt.shape[0] / 3)
        fine_list = []
        fine_rt_list = []
        for k in range(course_rt.shape[0]):
            candidates_rt = rt_list[k % group]
            distances = pose_distance(course_rt[k:k + 1], candidates_rt)
            num = np.argmin(distances)
            fine_list.append(name_list[k % group][num])
            fine_rt_list.append(candidates_rt[num])

        r_fine_loss = 0
        t_fine_loss = 0
        n_anchor = len(anchor_name)
        for i in range(n_anchor):
            # Retrieved rows for anchor i sit at offsets i, i + n, i + 2n.
            fine_anchor = cv2.imread(
                anchor_name[i].replace('pose.txt', 'color.png'))
            fine_1 = cv2.imread(
                args.data_root + fine_list[i].replace('pose.txt', 'color.png'))
            fine_2 = cv2.imread(
                args.data_root +
                fine_list[i + n_anchor].replace('pose.txt', 'color.png'))
            fine_3 = cv2.imread(
                args.data_root +
                fine_list[i + 2 * n_anchor].replace('pose.txt', 'color.png'))
            fine_anchor, fine_1, fine_2, fine_3 = train_transform(
                fine_anchor, fine_1, fine_2, fine_3)
            fine_anchor = torch.cat([fine_anchor.unsqueeze(0)] * 3, 0)
            fine_1 = torch.cat([
                fine_1.unsqueeze(0),
                fine_2.unsqueeze(0),
                fine_3.unsqueeze(0)
            ], 0)

            # torch.cat yields fresh tensors, so the in-place negation below does
            # not modify absolute_ranchor.
            fine_anchor_r = torch.cat([absolute_ranchor[i:i + 1]] * 3, 0)
            fine_anchor_t = torch.cat([absolute_tanchor[i:i + 1]] * 3, 0)
            fine_imgs_rt = np.array([
                fine_rt_list[i], fine_rt_list[i + n_anchor],
                fine_rt_list[i + 2 * n_anchor]
            ]).astype(np.float32)
            # Columns 0..3 hold the rotation, the rest the translation.
            fine_imgs_r = torch.from_numpy(fine_imgs_rt[:, :4])
            fine_imgs_t = torch.from_numpy(fine_imgs_rt[:, 4:])
            fine_rela_t = fine_imgs_t - fine_anchor_t
            fine_rela_t_gpu = fine_rela_t.cuda()
            fine_anchor_r[:, 1:] *= -1
            fine_rela_r = get_coarse_quaternion(fine_anchor_r, fine_imgs_r)
            fine_rela_r_gpu = fine_rela_r.cuda()
            _, _, anchor_PRP = model(fine_anchor.cuda())
            _, _, imgs_PRP = model(fine_1.cuda())
            trans_PRP, quat_PRP = model_prp(
                torch.cat([anchor_PRP, imgs_PRP], 1))
            r_fine_loss += criterion(
                quat_PRP.squeeze(3).squeeze(2),
                fine_rela_r_gpu) / world_size / n_anchor
            t_fine_loss += criterion(
                trans_PRP.squeeze(3).squeeze(2),
                fine_rela_t_gpu) / world_size / n_anchor
            # Debug trace of the retrieved fine pairs (rank 0 only).
            if rank == 0:
                print(anchor_name[i], args.data_root + fine_list[i],
                      fine_list[i + n_anchor],
                      fine_list[i + 2 * n_anchor],
                      fine_anchor_r, fine_anchor_t, fine_rela_t, fine_rela_r)

        loss = r_loss + t_loss + proj_loss + triple_loss + r_fine_loss + t_fine_loss

        optimizer.zero_grad()
        loss.backward()
        average_gradients(model)
        # The heads also feed the loss. Synchronise their gradients so each rank
        # keeps identical copies (only extra traffic if they are not in the optimizer).
        average_gradients(model_icr)
        average_gradients(model_pfr)
        average_gradients(model_prp)
        optimizer.step()

        reduced_loss = loss.detach().clone()
        reduced_t_loss = t_loss.detach().clone()
        reduced_r_loss = r_loss.detach().clone()
        reduced_proj_loss = proj_loss.detach().clone()
        reduced_triple_loss = triple_loss.detach().clone()
        reduced_rfine_loss = r_fine_loss.detach().clone()
        reduced_tfine_loss = t_fine_loss.detach().clone()
        dist.all_reduce(reduced_loss)
        dist.all_reduce(reduced_t_loss)
        dist.all_reduce(reduced_r_loss)
        dist.all_reduce(reduced_proj_loss)
        dist.all_reduce(reduced_triple_loss)
        dist.all_reduce(reduced_rfine_loss)
        dist.all_reduce(reduced_tfine_loss)

        # .item() instead of ``[0]``: indexing a 0-dim loss tensor raises in PyTorch >= 0.5.
        n = image_anchor.size(0)
        r_loss_meter.update(reduced_r_loss.item(), n)
        t_loss_meter.update(reduced_t_loss.item(), n)
        proj_loss_meter.update(reduced_proj_loss.item(), n)
        triple_loss_meter.update(reduced_triple_loss.item(), n)
        rfine_loss_meter.update(reduced_rfine_loss.item(), n)
        tfine_loss_meter.update(reduced_tfine_loss.item(), n)
        loss_meter.update(reduced_loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        # Estimate the remaining time from the running average batch time.
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        if rank == 0:
            if (index + 1) % args.print_freq == 0:
                logger.info(
                    'Epoch: [{}/{}][{}/{}] '
                    'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                    'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    'Remain {remain_time} '
                    'rLoss {r_loss_meter.val:.4f} '
                    'tLoss {t_loss_meter.val:.4f} '
                    'projLoss {proj_loss_meter.val:.4f} '
                    'tripleLoss {triple_loss_meter.val:.4f} '
                    'rfineLoss {rfine_loss_meter.val:.4f} '
                    'tfineLoss {tfine_loss_meter.val:.4f} '
                    'Loss {loss_meter.val:.4f} '.format(
                        epoch,
                        args.epochs,
                        index + 1,
                        len(train_loader),
                        batch_time=batch_time,
                        data_time=data_time,
                        remain_time=remain_time,
                        t_loss_meter=t_loss_meter,
                        r_loss_meter=r_loss_meter,
                        proj_loss_meter=proj_loss_meter,
                        triple_loss_meter=triple_loss_meter,
                        rfine_loss_meter=rfine_loss_meter,
                        tfine_loss_meter=tfine_loss_meter,
                        loss_meter=loss_meter))
            writer.add_scalar('loss_train_batch_r', r_loss_meter.val,
                              current_iter)
            writer.add_scalar('loss_train_batch_t', t_loss_meter.val,
                              current_iter)

    if rank == 0:
        logger.info('Train result at epoch [{}/{}].'.format(
            epoch, args.epochs))
    return t_loss_meter.avg, r_loss_meter.avg
def _warp_label_to_cuda(label_hw1, flow_patch, inverse_flow_patch):
    """Warp a single-channel label map along a flow field and move it to the GPU.

    The labels are converted with ``label_to_prob``, warped with
    ``flo.get_warp_label`` using the forward and inverse flow, converted back
    with ``prob_to_label``, and returned as a float CUDA tensor with two leading
    singleton dimensions added (the layout ``train`` passes to the model).

    Args:
        label_hw1: numpy label array laid out as (H, W, 1).
        flow_patch: flow tensor; squeezed and transposed CHW -> HWC before use
            (this assumes a batch size of 1 -- TODO confirm with the loader).
        inverse_flow_patch: inverse flow tensor, handled like ``flow_patch``.

    Returns:
        Float tensor on the current CUDA device.
    """
    prob = np.squeeze(label_to_prob(label_hw1, 1))
    flow = np.squeeze(flow_patch.cpu().numpy()).transpose((1, 2, 0))
    inverse_flow = np.squeeze(inverse_flow_patch.cpu().numpy()).transpose((1, 2, 0))
    warped = prob_to_label(flo.get_warp_label(flow, inverse_flow, prob))
    warped = torch.from_numpy(warped).contiguous().float()
    return warped.unsqueeze(0).unsqueeze(0).cuda(non_blocking=True)


def train(train_loader, model, criterion, optimizer, epoch, zoom_factor, batch_size, ins_weight):
    """Run one training epoch of the two-step, flow-warped instance-mask model.

    For every sample the model is run twice:

    1. on the previous frame (``pre_image_patch_2``), conditioned on
       ``pre_aug_mask`` warped with ``flow_patch`` / ``inverse_flow_patch``;
    2. on the current frame (``image_patch``), conditioned on the argmax of the
       step-1 output warped with ``flow_patch_2`` / ``inverse_flow_patch_2``.

    Only the step-2 loss is back-propagated; the step-1 loss is tracked for
    logging. Both losses are divided by the world size and summed across
    processes with ``dist.all_reduce`` before being recorded. After
    ``backward()`` the function calls ``average_gradients(model)``.

    Args:
        train_loader: yields 11-tuples (see the loop unpacking below).
        model: network called as ``model(image, warped_mask, inverse_flow)``.
        criterion: loss taking (model output, long target map).
        optimizer: optimizer whose learning rate is set by ``poly_learning_rate``.
        epoch: epoch number; ``current_iter`` uses ``epoch - 1``, so presumably
            1-based.
        zoom_factor, batch_size, ins_weight: accepted for interface
            compatibility; not used by this function.

    Returns:
        (average current-frame loss, mIoU, mAcc, allAcc) for the epoch.

    Raises:
        ValueError: if ``args.net_type`` is not 0 (the only configured value).
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()  # current-frame loss (back-propagated)
    loss_meter = AverageMeter()       # previous-frame loss (logged only)
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    # The LR parameter-group split point is only known for net_type 0; fail
    # fast with a clear message instead of an UnboundLocalError mid-loop.
    if args.net_type != 0:
        raise ValueError('unsupported args.net_type: {}'.format(args.net_type))
    index_split = 4

    model.train()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    max_iter = args.epochs * len(train_loader)
    end = time.time()

    for i, (pre_image_patch, pre_aug_mask, pre_ins_mask, flow_patch, inverse_flow_patch,
            pre_image_patch_2, pre_ins_mask_2, flow_patch_2, inverse_flow_patch_2,
            image_patch, ins_mask) in enumerate(train_loader):
        # These two loader outputs are unused; drop the references right away.
        del pre_image_patch, pre_ins_mask

        data_time.update(time.time() - end)
        current_iter = (epoch - 1) * len(train_loader) + i + 1
        poly_learning_rate(optimizer, args.base_lr, current_iter, max_iter,
                           power=args.power, index_split=index_split)

        # ---- Step 1: previous frame, conditioned on the warped pre_aug_mask.
        # Drop the batch axis, then CHW -> HWC.
        pre_label = np.squeeze(pre_aug_mask.cpu().numpy(), axis=0).transpose((1, 2, 0))
        warped_pre_mask = _warp_label_to_cuda(pre_label, flow_patch, inverse_flow_patch)
        inverse_flow_cuda = inverse_flow_patch.cuda(non_blocking=True)
        pre_input = pre_image_patch_2.cuda(non_blocking=True)
        pre_output_ins = model(pre_input, warped_pre_mask, inverse_flow_cuda)

        pre_ins_target = pre_ins_mask_2.squeeze(1).long().cuda(non_blocking=True)
        last_ins_loss = criterion(pre_output_ins, pre_ins_target) / world_size
        # NOTE(review): this previous-frame loss is only logged ("pre ins Loss");
        # it is not part of the objective back-propagated below, and the argmax
        # feeding step 2 is non-differentiable, so step 1 receives no gradient.
        # If a weighted two-frame loss was intended (a 0.8 "two loss weight"
        # factor was sketched here, and `ins_weight` is unused), add it to the
        # objective explicitly.

        # ---- Step 2: current frame, conditioned on the warped step-1 prediction.
        # Argmax over dim 0 of the first sample only (assumes batch size 1).
        pre_pred = torch.argmax(pre_output_ins[0], dim=0, keepdim=True).float()
        pre_pred_label = pre_pred.cpu().numpy().transpose((1, 2, 0))
        warped_pred = _warp_label_to_cuda(pre_pred_label, flow_patch_2, inverse_flow_patch_2)
        inverse_flow_2_cuda = inverse_flow_patch_2.cuda(non_blocking=True)
        image_cuda = image_patch.cuda(non_blocking=True)
        output_ins = model(image_cuda, warped_pred, inverse_flow_2_cuda)

        ins_mask = ins_mask.squeeze(1).long()
        ins_loss = criterion(output_ins, ins_mask.cuda(non_blocking=True)) / world_size

        optimizer.zero_grad()
        ins_loss.backward()
        average_gradients(model)
        optimizer.step()

        output = output_ins.data.max(1)[1].cpu().numpy()
        target = ins_mask.cpu().numpy()
        # NOTE: the class count (2) is hard-coded here rather than taken from args.
        intersection, union, target = intersectionAndUnion(output, target, 2, args.ignore_label)
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)

        reduced_loss = last_ins_loss.data.clone()
        reduced_main_loss = ins_loss.data.clone()
        dist.all_reduce(reduced_loss)
        dist.all_reduce(reduced_main_loss)

        # .item() rather than [0]: the losses are 0-dim tensors, and indexing a
        # 0-dim tensor is an error in PyTorch >= 0.5.
        main_loss_meter.update(reduced_main_loss.item(), image_patch.size(0))
        loss_meter.update(reduced_loss.item(), image_patch.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        # Estimate the remaining wall-clock time from the average batch time.
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        if rank == 0:
            if (i + 1) % args.print_freq == 0:
                logger.info('Epoch: [{}/{}][{}/{}] '
                            'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                            'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                            'Remain {remain_time} '
                            'cur ins Loss {main_loss_meter.val:.4f} '
                            'pre ins Loss {loss_meter.val:.4f} '
                            'Accuracy {accuracy:.4f}.'.format(epoch, args.epochs, i + 1, len(train_loader),
                                                              batch_time=batch_time,
                                                              data_time=data_time,
                                                              remain_time=remain_time,
                                                              main_loss_meter=main_loss_meter,
                                                              loss_meter=loss_meter,
                                                              accuracy=accuracy))
            writer.add_scalar('loss_train_batch', main_loss_meter.val, current_iter)
            writer.add_scalar('mIoU_train_batch', np.mean(intersection / (union + 1e-10)), current_iter)
            writer.add_scalar('mAcc_train_batch', np.mean(intersection / (target + 1e-10)), current_iter)
            writer.add_scalar('allAcc_train_batch', accuracy, current_iter)

    # Epoch-level metrics from the accumulated per-class sums.
    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    if rank == 0:
        logger.info(
            'Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(epoch, args.epochs, mIoU,
                                                                                           mAcc, allAcc))
    return main_loss_meter.avg, mIoU, mAcc, allAcc