def main():

    # Loss/geometry helpers used below (hoisted out of the training loops).
    import util
    from loss_func import disp_smoothness, SSIM

    seq_length = args.seq_length
    num_scales = 4

    torch.manual_seed(0)

    device = args.device

    # DispNetS and PoseExpNet follow the SfMLearner-style PyTorch networks.
    disp_net = DispNetS().to(device)
    disp_net.init_weights()

    pose_exp_net = PoseExpNet(nb_ref_imgs=seq_length - 1,
                              output_exp=False).to(device)
    pose_exp_net.init_weights()

    args_lr = args.learning_rate
    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args_lr
    }, {
        'params': pose_exp_net.parameters(),
        'lr': args_lr
    }]

    # For Adam, beta1 (here "momentum") and beta2 are the moment decay rates.
    args_momentum = 0.9
    args_beta = 0.999
    args_weight_decay = 0

    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args_momentum, args_beta),
                                 weight_decay=args_weight_decay)

    # cudnn.benchmark = True

    best_loss = 10000000

    args_epochs = 300
    for epoch in range(args_epochs):
        print("============================== epoch:", epoch)

        disp_net.train()
        pose_exp_net.train()

        c_time = time.time()
        # Start of a single epoch.

        running_loss = 0.0

        for loader_idx, (image_stack, image_stack_norm, intrinsic_mat,
                         _) in enumerate(train_dataloader):

            image_stack = [img.to(device) for img in image_stack]
            image_stack_norm = [img.to(device) for img in image_stack_norm]
            intrinsic_mat = intrinsic_mat.to(device)  # [B, 4, 3, 3]: one 3x3 intrinsic matrix per scale

            disp = {}
            depth = {}
            depth_upsampled = {}

            for seq_i in range(seq_length):
                multiscale_disps_i, _ = disp_net(image_stack[seq_i])
                # Disparity pyramid: [1,1,128,416], [1,1,64,208], [1,1,32,104], [1,1,16,52]

                # if seq_i == 1:
                #     dd = multiscale_disps_i[0]
                #     dd = dd.detach().cpu().numpy()
                #     np.save( "./rst/" + str(loader_idx) + ".npy", dd)

                # Depth is the reciprocal of the predicted disparity at every scale.
                multiscale_depths_i = [1.0 / d for d in multiscale_disps_i]
                disp[seq_i] = multiscale_disps_i
                depth[seq_i] = multiscale_depths_i

                depth_upsampled[seq_i] = []

                # Upsample every scale back to full resolution so the photometric
                # loss can always be computed on [128, 416] images.
                for s in range(num_scales):
                    depth_upsampled[seq_i].append(
                        nn.functional.interpolate(multiscale_depths_i[s],
                                                  size=[128, 416],
                                                  mode='bilinear',
                                                  align_corners=True))

            # The middle frame is the target; its two neighbors are the
            # reference frames.
            egomotion = pose_exp_net(
                image_stack_norm[1],
                [image_stack_norm[0], image_stack_norm[2]])
            # egomotion: [B, 2, 6], one 6-DoF vector per reference frame

            # Build the loss ======================================
            middle_frame_index = (seq_length - 1) // 2  # for frames 0, 1, 2 the middle is 1

            # images is organized as images[scale][frame], each [B, 3, h_s, w_s].
            images = [None for _ in range(num_scales)]

            # Pre-scale the images at every scale for the loss computation below.
            for s in range(num_scales):
                height_s = int(128 / (2**s))
                width_s = int(416 / (2**s))

                images[s] = [
                    nn.functional.interpolate(x,
                                              size=[height_s, width_s],
                                              mode='bilinear',
                                              align_corners=True)
                    for x in image_stack
                ]
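
            # Scale pyramid: s=0 -> 128x416, s=1 -> 64x208, s=2 -> 32x104, s=3 -> 16x52.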

            # Loss-configuration flags (hoisted out of the loops below).
            compute_minimum_loss = True
            exhaustive_mode = False
            depth_upsampling = True

            smooth_loss = 0  # accumulate the smoothness loss over all scales
            for s in range(num_scales):
                # Smoothness: with the minimum loss, only the middle frame counts.
                for i in range(seq_length):
                    if not compute_minimum_loss or i == middle_frame_index:
                        disp_smoothing = disp[i][s]

                        # Normalize disparity by its mean before smoothing.
                        mean_disp = torch.mean(disp_smoothing, (1, 2, 3), True)
                        disp_input = disp_smoothing / mean_disp

                        # Edge-aware smoothness, weighted by 1/2^s per scale.
                        smooth_loss += (1.0 / (2**s)) * disp_smoothness(
                            disp_input, images[s][i])

            # Following nested lists are organized by ...[scale][source-target].
            warped_image = [{} for _ in range(num_scales)]
            warp_mask = [{} for _ in range(num_scales)]
            warp_error = [{} for _ in range(num_scales)]
            ssim_error = [{} for _ in range(num_scales)]

            reconstr_loss = 0
            ssim_loss = 0

            for s in range(num_scales):

                for i in range(seq_length):
                    for j in range(seq_length):
                        if i == j:
                            continue

                        # When computing minimum loss, only consider the middle frame as target.
                        if compute_minimum_loss and j != middle_frame_index:
                            continue

                        if (not compute_minimum_loss and not exhaustive_mode
                                and abs(i - j) != 1):
                            continue

                        selected_scale = 0 if depth_upsampling else s
                        source = images[selected_scale][i]
                        target = images[selected_scale][j]

                        if depth_upsampling:
                            target_depth = depth_upsampled[j][s]
                        else:
                            target_depth = depth[j][s]

                        key = '%d-%d' % (i, j)

                        # egomotion is [B, 2, 6]; get_transform_mat converts it
                        # into a 4x4 transform taking frame i into frame j.
                        egomotion_mat_i_j = util.get_transform_mat(
                            egomotion, i, j)
                        # egomotion_mat_i_j: [B, 4, 4]

                        # inverse_warp resamples `source` into the target view
                        # from the target depth, the 3x4 pose and the scale-0
                        # intrinsics; warp_mask marks valid (in-frame) pixels.
                        warped_image[s][key], warp_mask[s][key] = \
                            util.inverse_warp(source,
                                              target_depth.squeeze(1),
                                              egomotion_mat_i_j[:, 0:3, :],
                                              intrinsic_mat[:, selected_scale, :, :]
                                              )

                        # print("inverse_warp success!")
                        # a = input("pasue ...")

                        # Reconstruction loss.
                        warp_error[s][key] = torch.abs(warped_image[s][key] -
                                                       target)
                        if not compute_minimum_loss:
                            reconstr_loss += torch.mean(warp_error[s][key] *
                                                        warp_mask[s][key])

                        # SSIM.
                        ssim_error[s][key] = SSIM(warped_image[s][key], target)

                        # TODO(rezama): This should be min_pool2d().
                        if not compute_minimum_loss:
                            ssim_mask = nn.AvgPool2d(3, 1)(warp_mask[s][key])
                            ssim_loss += torch.mean(ssim_error[s][key] *
                                                    ssim_mask)

            for s in range(num_scales):
                # If the minimum loss should be computed, the loss calculation has been
                # postponed until here.
                if compute_minimum_loss:
                    for frame_index in range(middle_frame_index):
                        key1 = '%d-%d' % (frame_index, middle_frame_index)
                        key2 = '%d-%d' % (seq_length - frame_index - 1,
                                          middle_frame_index)

                        min_error = torch.min(warp_error[s][key1],
                                              warp_error[s][key2])
                        reconstr_loss += torch.mean(min_error)

                        # Also compute the minimum SSIM loss.
                        min_error_ssim = torch.min(ssim_error[s][key1],
                                                   ssim_error[s][key2])
                        ssim_loss += torch.mean(min_error_ssim)

            # Weighted sum of the three loss terms (struct2depth-style weights).
            total_loss = 0.85 * reconstr_loss + 0.04 * smooth_loss + 0.15 * ssim_loss

            if loader_idx % 200 == 0:
                print("idx: %4d reconstr: %.5f  smooth: %.5f  ssim: %.5f  total: %.5f" % \
                    (loader_idx, reconstr_loss, smooth_loss, ssim_loss, total_loss) )

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()

            ############# end of a single epoch

        # Average the summed per-batch losses over the number of batches.
        running_loss /= len(train_dataloader)

        if running_loss < best_loss:
            best_loss = running_loss
            print("\n* best loss:", best_loss, "\n")

            # torch.save(disp_net.state_dict(),     './wt_tiny/disp_net_best.pth' )
            # torch.save(pose_exp_net.state_dict(), './wt_tiny/pose_exp_net_best.pth' )

        print('Epoch:', epoch, 'train_loss:', running_loss, 'time:',
              round(time.time() - c_time, 3), 's')
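
The loss helpers imported from loss_func above (disp_smoothness and SSIM) are not part of this listing. A minimal sketch of what they typically look like in SfMLearner/struct2depth-style PyTorch code is below; the constants and pooling choices are common defaults, and the repository's actual implementations may differ in detail.

import torch
import torch.nn.functional as F

def disp_smoothness(disp, img):
    # Edge-aware first-order smoothness: disparity gradients are
    # down-weighted where the image itself has strong gradients.
    disp_dx = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
    disp_dy = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
    img_dx = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]),
                        1, keepdim=True)
    img_dy = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]),
                        1, keepdim=True)
    return (disp_dx * torch.exp(-img_dx)).mean() + \
           (disp_dy * torch.exp(-img_dy)).mean()

def SSIM(x, y):
    # Per-pixel DSSIM map in [0, 1]; 3x3 average pooling stands in
    # for the Gaussian window of the original SSIM formulation.
    c1, c2 = 0.01 ** 2, 0.03 ** 2
    mu_x, mu_y = F.avg_pool2d(x, 3, 1), F.avg_pool2d(y, 3, 1)
    sigma_x = F.avg_pool2d(x ** 2, 3, 1) - mu_x ** 2
    sigma_y = F.avg_pool2d(y ** 2, 3, 1) - mu_y ** 2
    sigma_xy = F.avg_pool2d(x * y, 3, 1) - mu_x * mu_y
    ssim_n = (2 * mu_x * mu_y + c1) * (2 * sigma_xy + c2)
    ssim_d = (mu_x ** 2 + mu_y ** 2 + c1) * (sigma_x + sigma_y + c2)
    return torch.clamp((1 - ssim_n / ssim_d) / 2, 0, 1)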
Example #2
def main():

    print("train_size:", train_size)
    print("valid_size:", valid_size)

    seq_length = args.seq_length
    num_scales = 4

    torch.manual_seed(0)

    device = args.device
    
    disp_net = DispNetS().to(device) 
    disp_net.init_weights()
    
    pose_exp_net = PoseExpNet(nb_ref_imgs=seq_length - 1, output_exp=False).to(device)
    pose_exp_net.init_weights()
    
    args_lr = args.learning_rate
    optim_params = [
        {'params': disp_net.parameters(), 'lr': args_lr},
        {'params': pose_exp_net.parameters(), 'lr': args_lr}
    ]

    args_momentum = 0.9
    args_beta = 0.999
    args_weight_decay = 0

    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args_momentum, args_beta),
                                 weight_decay=args_weight_decay)


    start_epoch = 0

    # continue_train = 1  # whether to resume training from a checkpoint
    # if continue_train:  # if resuming, load the previously saved state
    #     # Note: saving and loading the weights of multiple networks is still
    #     # unresolved here; save()/load() would need to handle a list like
    #     # [model1, model2, ...].
    #     model, optimizer, epoch = load(args.previous_ckpt_path, model, optimizer)
    #     start_epoch = epoch + 1
    #     model = model.to(device)
    #     set_lr(optimizer, get_lr(start_epoch, args.learning_rate))
    #     print("\n")
    #     print("load previous checkpoint successfully!")
    #     print("start_epoch:", start_epoch)
    #     print("\n")
    # else:  # otherwise train from scratch; nothing else to do
    #     model = model.to(device)
    #     print("\n")
    #     print("train from scratch!")
    #     print("start_epoch:", start_epoch)
    #     print("\n")


    # cudnn.benchmark = True
    
    best_loss = float('Inf')

    args_epochs = 300
    # Starting from start_epoch supports resuming from a checkpoint.
    for epoch in range(start_epoch, args_epochs):
        print("============================== epoch:", epoch)
        
        disp_net.train()
        pose_exp_net.train()
        
        c_time = time.time()
        # Start of a single epoch.

        running_loss = 0.0

        for loader_idx, (image_stack, image_stack_norm, intrinsic_mat, _) in enumerate(train_dataloader):
            
            image_stack = [img.to(device) for img in image_stack]
            image_stack_norm = [img.to(device) for img in image_stack_norm]
            intrinsic_mat = intrinsic_mat.to(device)  # [B, 4, 3, 3]: one 3x3 intrinsic matrix per scale
            
            disp = {}
            depth = {}
            depth_upsampled = {}
            
            for seq_i in range(seq_length):
                multiscale_disps_i, _ = disp_net(image_stack[seq_i])
                # [1,1,128,416], [1,1,64,208],[1,1,32,104],[1,1,16,52]

                # if seq_i == 1:
                #     dd = multiscale_disps_i[0]
                #     dd = dd.detach().cpu().numpy()
                #     np.save( "./rst/" + str(loader_idx) + ".npy", dd)
                
                multiscale_depths_i = [1.0 / d for d in multiscale_disps_i]
                disp[seq_i] = multiscale_disps_i
                depth[seq_i] = multiscale_depths_i
                
                depth_upsampled[seq_i] = []
                
                for s in range(num_scales):
                    depth_upsampled[seq_i].append( nn.functional.interpolate(multiscale_depths_i[s],
                                   size=[128, 416], mode='bilinear', align_corners=True) )
                    
            # The middle frame is the target; its two neighbors are the
            # reference frames.
            egomotion = pose_exp_net(image_stack_norm[1], [image_stack_norm[0], image_stack_norm[2]])
            # egomotion: [B, 2, 6], one 6-DoF vector per reference frame

            # Build the loss ======================================
            from loss_func import calc_total_loss

            total_loss, reconstr_loss, smooth_loss, ssim_loss = \
                calc_total_loss(image_stack, disp, depth, depth_upsampled,
                                egomotion, intrinsic_mat)
            # Total loss computed ================================

            # Print roughly every 200 samples; integer division keeps the modulus an int.
            if loader_idx % max(1, 200 // args.batchsize) == 0:
                print("idx: %4d reconstr: %.5f  smooth: %.5f  ssim: %.5f  total: %.5f" % \
                    (loader_idx, reconstr_loss, smooth_loss, ssim_loss, total_loss) )

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()

        ############# end of a single epoch

        running_loss /= (train_size/args.batchsize)

        if running_loss < best_loss:
            best_loss = running_loss

            print("* best loss:", best_loss )

            torch.save(disp_net.state_dict(),     './disp_net_best.pth' )
            torch.save(pose_exp_net.state_dict(), './pose_exp_net_best.pth' )

        print ( 'Epoch:', epoch, 
                'train_loss:', running_loss,
                'time:', round(time.time() - c_time, 3), 's')
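
Example #2 and Example #3 collapse the inline loss computation of Example #1 into a single call into loss_func.calc_total_loss, whose body is not part of this listing. Inferred from the call sites, its contract is roughly the following sketch (hypothetical; the real function packages the smoothness, reconstruction and SSIM terms exactly as written out in Example #1):

def calc_total_loss(image_stack,      # seq_length frames, each [B, 3, 128, 416]
                    disp,             # disp[frame][scale]: multiscale disparities
                    depth,            # depth[frame][scale] = 1 / disp
                    depth_upsampled,  # per-frame depths resampled to [128, 416]
                    egomotion,        # [B, seq_length - 1, 6] pose vectors
                    intrinsic_mat):   # [B, num_scales, 3, 3] intrinsics
    ...  # smoothness + min-reconstruction + min-SSIM, as in Example #1
    # returns (total_loss, reconstr_loss, smooth_loss, ssim_loss), where
    # total_loss = 0.85 * reconstr_loss + 0.04 * smooth_loss + 0.15 * ssim_loss

Example #3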
def main(T, AoI):
    import data_loader
    ds_path = './dataset'
    ds = data_loader.SequenceFolder()
    ds.init(ds_path, AoI, 0, 3, 4)

    # 90/10 split of the full dataset into training and validation subsets.
    train_size = int(0.9 * len(ds))
    valid_size = len(ds) - train_size

    train_dataset, valid_dataset = random_split(ds, [train_size, valid_size])

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batchsize,
                                  shuffle=True,
                                  num_workers=4)

    valid_dataloader = DataLoader(valid_dataset,
                                  batch_size=args.batchsize,
                                  shuffle=True,
                                  num_workers=4)

    print("train_size:", train_size)
    print("valid_size:", valid_size)
    seq_length = args.seq_length
    num_scales = 4

    torch.manual_seed(0)

    device = args.device

    disp_net = DispNetS().to(device)
    disp_net.init_weights()

    pose_exp_net = PoseExpNet(nb_ref_imgs=seq_length - 1,
                              output_exp=False).to(device)
    pose_exp_net.init_weights()

    args_lr = args.learning_rate
    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args_lr
    }, {
        'params': pose_exp_net.parameters(),
        'lr': args_lr
    }]

    args_momentum = 0.9
    args_beta = 0.999
    args_weight_decay = 0

    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args_momentum, args_beta),
                                 weight_decay=args_weight_decay)

    start_epoch = 0

    # continue_train = 1  # whether to resume training from a checkpoint
    # if continue_train:  # if resuming, load the previously saved state
    #
    #     model, optimizer, epoch = load(args.previous_ckpt_path, model, optimizer)
    #     start_epoch = epoch + 1
    #     model = model.to(device)
    #     set_lr(optimizer, get_lr(start_epoch, args.learning_rate))
    #     print("\n")
    #     print("load previous checkpoint successfully!")
    #     print("start_epoch:", start_epoch)
    #     print("\n")
    # else:  # otherwise train from scratch
    #     model = model.to(device)
    #     print("\n")
    #     print("train from scratch!")
    #     print("start_epoch:", start_epoch)
    #     print("\n")

    # cudnn.benchmark = True

    best_loss = float('Inf')
    best_valid_loss = float('Inf')
    optimal_epoch = 0
    optimal_valid_epoch = 0
    loss_list = []
    valid_list = []
    args_epochs = 100

    print("============================== Trainning start")
    # for epoch in range(args_epochs):
    for epoch in range(start_epoch, args_epochs):  #

        disp_net.train()
        pose_exp_net.train()

        c_time = time.time()
        # Start of a single epoch.

        running_loss = 0.0
        valid_loss = 0.0  # reset each epoch; accumulated over the validation loop below

        for loader_idx, (image_stack, image_stack_norm, intrinsic_mat,
                         _) in enumerate(train_dataloader):
            image_stack = [img.to(device) for img in image_stack]
            image_stack_norm = [img.to(device) for img in image_stack_norm]
            intrinsic_mat = intrinsic_mat.to(device)  # [B, 4, 3, 3]: one 3x3 intrinsic matrix per scale
            disp = {}
            depth = {}
            depth_upsampled = {}

            for seq_i in range(seq_length):
                multiscale_disps_i, _ = disp_net(image_stack[seq_i])
                # [1,1,128,416], [1,1,64,208],[1,1,32,104],[1,1,16,52]

                # if seq_i == 1:
                #     dd = multiscale_disps_i[0]
                #     dd = dd.detach().cpu().numpy()
                #     np.save( "./rst/" + str(loader_idx) + ".npy", dd)

                multiscale_depths_i = [1.0 / d for d in multiscale_disps_i]
                disp[seq_i] = multiscale_disps_i
                depth[seq_i] = multiscale_depths_i

                depth_upsampled[seq_i] = []

                for s in range(num_scales):
                    depth_upsampled[seq_i].append(
                        nn.functional.interpolate(multiscale_depths_i[s],
                                                  size=[128, 416],
                                                  mode='bilinear',
                                                  align_corners=True))

            # Target changed from the middle frame to the last frame; both
            # earlier frames serve as references.
            egomotion = pose_exp_net(
                image_stack_norm[2],
                [image_stack_norm[0], image_stack_norm[1]])
            # egomotion: [B, 2, 6]

            # Build the loss ======================================
            from loss_func import calc_total_loss

            total_loss, reconstr_loss, smooth_loss, ssim_loss = \
                calc_total_loss(image_stack, disp, depth, depth_upsampled,
                                egomotion, intrinsic_mat)
            # Total loss computed ================================

            # Print roughly every 200 samples; integer division keeps the modulus an int.
            if loader_idx % max(1, 200 // args.batchsize) == 0:
                print("idx: %4d reconstr: %.5f  smooth: %.5f  ssim: %.5f  total: %.5f" % \
                    (loader_idx, reconstr_loss, smooth_loss, ssim_loss, total_loss) )

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()
        running_loss /= (train_size / args.batchsize)
        loss_list.append(running_loss)
        if running_loss < best_loss:
            best_loss = running_loss
            optimal_epoch = epoch
            #print("* best loss:", best_loss )

            torch.save(disp_net.state_dict(), './disp_net_best.pth')
            torch.save(pose_exp_net.state_dict(), './pose_exp_net_best.pth')

        # The validation pass below runs in train mode and builds autograd
        # graphs; calling .eval() on both networks and wrapping the loop in
        # `with torch.no_grad():` would be the usual pattern.
        for loader_idx_val, (image_stack_val, image_stack_norm_val,
                             intrinsic_mat_val,
                             _) in enumerate(valid_dataloader):

            image_stack_val = [img.to(device) for img in image_stack_val]
            image_stack_norm_val = [
                img.to(device) for img in image_stack_norm_val
            ]
            intrinsic_mat_val = intrinsic_mat_val.to(device)  # [B, 4, 3, 3]

            disp_val = {}
            depth_val = {}
            depth_upsampled_val = {}

            for seq_i in range(seq_length):
                multiscale_disps_val_i, _ = disp_net(image_stack_val[seq_i])
                # [1,1,128,416], [1,1,64,208],[1,1,32,104],[1,1,16,52]
                multiscale_depths_val_i = [
                    1.0 / d for d in multiscale_disps_val_i
                ]
                disp_val[seq_i] = multiscale_disps_val_i
                depth_val[seq_i] = multiscale_depths_val_i
                depth_upsampled_val[seq_i] = []

                for s in range(num_scales):
                    depth_upsampled_val[seq_i].append(
                        nn.functional.interpolate(multiscale_depths_val_i[s],
                                                  size=[128, 416],
                                                  mode='bilinear',
                                                  align_corners=True))

            # Target changed from the middle frame to the last frame.
            egomotion_val = pose_exp_net(
                image_stack_norm_val[2],
                [image_stack_norm_val[0], image_stack_norm_val[1]])
            # egomotion_val: [B, 2, 6]

            # Build the loss ======================================
            from loss_func import calc_total_loss

            total_loss_val, reconstr_loss_val, smooth_loss_val, ssim_loss_val = \
                calc_total_loss(image_stack_val, disp_val, depth_val,
                                depth_upsampled_val, egomotion_val,
                                intrinsic_mat_val)
            # Total loss computed ================================
            valid_loss += total_loss_val.item()
        valid_loss /= (valid_size / args.batchsize)
        valid_list.append(valid_loss)
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            optimal_valid_epoch = epoch
        ############# end of epoch
    print("============================== Training End")
    print('time:', round(time.time() - c_time, 3), 's', 'Best training loss:',
          best_loss, 'Optimal training epoch is:', optimal_epoch,
          'Best validation loss:', best_valid_loss,
          'Optimal validation epoch is:', optimal_valid_epoch, 'AOI', AoI)
    valid_list.sort()  # valid_list[0] becomes the best (lowest) validation loss
    '''
    x = np.arange(0,args_epochs)
    plt.plot(x,loss_list,'r--',label='Training Loss') 
    plt.plot(x,valid_list,'g--',label='Validation Loss')
    plt.title('Training Loss vs Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')   
    plt.ylim(0,1)
    plt.xlim(0,110)
    plt.xticks(range(len(loss_list)))
    x_major_locator=MultipleLocator(5)
    ax=plt.gca()
    ax.xaxis.set_major_locator(x_major_locator)
    plt.legend()
    plt.show()
    '''

    # Append best validation loss, best training loss, and T * AoI for this run.
    with open("loss_new.txt", 'a') as fw:
        fw.write('{} {} {}\n'.format(valid_list[0], best_loss, T * AoI))

    from importlib import reload  # reload() is not a builtin in Python 3
    reload(data_loader)
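
Example #3 parameterizes main by T and AoI and appends one line per run (best validation loss, best training loss, T * AoI) to loss_new.txt. A hypothetical driver, with placeholder T and AoI values that are not taken from the original code, might look like:

if __name__ == '__main__':
    for AoI in [1, 2, 3]:   # area-of-interest identifiers (assumed values)
        main(T=1, AoI=AoI)  # T only enters the T * AoI log entry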