Example #1
0
def train(args):
    """Train a region depth-estimation network and periodically validate it.

    Builds the train/val dataloaders for the 'region' task, optionally opens
    visdom windows, constructs the model/optimizer/loss, resumes from a
    checkpoint when ``args.resume`` points at one (otherwise warm-starts from
    a ResNet-34 snapshot), then runs the epoch loop: train on every batch,
    validate every ``check`` epochs, and save both best and periodic
    checkpoints.

    Args:
        args: argparse namespace providing at least ``dataset``, ``arch``,
            ``img_rows``, ``img_cols``, ``batch_size``, ``l_rate``,
            ``n_epoch``, ``resume`` and ``visdom``.
    """

    # Setup Augmentations (constructed but not handed to the loaders below).
    data_aug = Compose([RandomRotate(10),
                        RandomHorizontallyFlip()])
    loss_rec = []     # [iteration index, loss tensor] pairs, persisted to loss.npy
    best_error = 2    # lowest validation error seen so far

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path, is_transform=True,
                           split='train_region', img_size=(args.img_rows, args.img_cols), task='region')
    v_loader = data_loader(data_path, is_transform=True,
                           split='test_region', img_size=(args.img_rows, args.img_cols), task='region')

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader, batch_size=args.batch_size, num_workers=4, shuffle=True)
    valloader = data.DataLoader(
        v_loader, batch_size=args.batch_size, num_workers=4)

    # Setup Metrics (instantiated for parity with sibling scripts; not
    # updated anywhere in this loop).
    running_metrics = runningScore(n_classes)

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom()
        # Replays the loss history of already-trained epochs on resume.
        old_window = vis.line(X=torch.zeros((1,)).cpu(),
                              Y=torch.zeros((1)).cpu(),
                              opts=dict(xlabel='minibatches',
                                        ylabel='Loss',
                                        title='Trained Loss',
                                        legend=['Loss']))
        # Live per-minibatch training loss.
        loss_window = vis.line(X=torch.zeros((1,)).cpu(),
                               Y=torch.zeros((1)).cpu(),
                               opts=dict(xlabel='minibatches',
                                         ylabel='Loss',
                                         title='Training Loss',
                                         legend=['Loss']))
        # Image windows showing the current prediction and ground truth.
        pre_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='predict!', caption='predict.'),
        )
        ground_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='ground!', caption='ground.'),
        )

    # Setup Model, replicated across all visible GPUs.
    model = get_model(args.arch)
    model = torch.nn.DataParallel(
        model, device_ids=range(torch.cuda.device_count()))
    model.cuda()

    # Prefer a model-supplied optimizer / loss when the module defines one.
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(
            model.parameters(), lr=args.l_rate, momentum=0.99, weight_decay=5e-4)
    if hasattr(model.module, 'loss'):
        print('Using custom loss')
        loss_fn = model.module.loss
    else:
        loss_fn = region_log
    trained = 0   # number of epochs already completed (from a checkpoint)

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            print("Loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            trained = checkpoint['epoch']
            best_error = checkpoint['error']
            # Reload the persisted loss history (816 iterations per epoch)
            # and replay the per-epoch mean into the 'Trained Loss' window.
            loss_rec = np.load('/home/lidong/Documents/RSDEN/RSDEN/loss.npy')
            loss_rec = list(loss_rec)
            loss_rec = loss_rec[:816 * trained]
            for l in range(int(len(loss_rec) / 816)):
                if args.visdom:
                    vis.line(
                        X=torch.ones(1).cpu() * loss_rec[l * 816][0],
                        Y=np.mean(np.array(loss_rec[l * 816:(l + 1) * 816])[:, 1]) * torch.ones(1).cpu(),
                        win=old_window,
                        update='append')
    else:
        print("No checkpoint found at '{}'".format(args.resume))
        print('Initialize from resnet34!')
        # Warm-start every parameter whose name matches the ResNet-34 snapshot.
        resnet34 = torch.load('/home/lidong/Documents/RSDEN/RSDEN/resnet34-333f7ec4.pth')
        model_dict = model.state_dict()
        pre_dict = {k: v for k, v in resnet34.items() if k in model_dict}
        model_dict.update(pre_dict)
        model.load_state_dict(model_dict)
        print('load success!')
        best_error = 1
        trained = 0

    # FIX: define `error` before the epoch loop. Previously it was first
    # assigned only inside the validation branch, so the unconditional
    # `epoch % 5 == 0` checkpoint below raised NameError on any epoch
    # where no validation pass had run yet.
    error = best_error

    for epoch in range(trained, args.n_epoch):
        print('training!')
        model.train()
        for i, (images, labels, segments) in enumerate(trainloader):
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())
            segments = Variable(segments.cuda())
            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(input=outputs, target=labels, instance=segments)
            loss.backward()
            optimizer.step()
            # Iterations per epoch: nyu2_train 246, nyu2_all 816.
            if args.visdom:
                vis.line(
                    X=torch.ones(1).cpu() * i + torch.ones(1).cpu() * (epoch - trained) * 816,
                    Y=loss.item() * torch.ones(1).cpu(),
                    win=loss_window,
                    update='append')
                # Min-max normalize the first prediction in the batch for display.
                pre = outputs.data.cpu().numpy().astype('float32')
                pre = pre[0, :, :, :]
                pre = (np.reshape(pre, [480, 640]).astype('float32') - np.min(pre)) / (np.max(pre) - np.min(pre))
                vis.image(
                    pre,
                    opts=dict(title='predict!', caption='predict.'),
                    win=pre_window,
                )
                ground = labels.data.cpu().numpy().astype('float32')
                ground = ground[0, :, :]
                ground = (np.reshape(ground, [480, 640]).astype('float32') - np.min(ground)) / (np.max(ground) - np.min(ground))
                vis.image(
                    ground,
                    opts=dict(title='ground!', caption='ground.'),
                    win=ground_window,
                )

            loss_rec.append([i + epoch * 816, torch.Tensor([loss.item()]).unsqueeze(0).cpu()])
            print("data [%d/816/%d/%d] Loss: %.4f" % (i, epoch, args.n_epoch, loss.item()))

        # Validate more often as training progresses.
        if epoch > 50:
            check = 3
        else:
            check = 5
        if epoch > 70:
            check = 2
        if epoch > 85:
            check = 1
        if epoch % check == 0:
            print('testing!')
            # NOTE(review): validation runs with the model still in train
            # mode; if the network contains BatchNorm/Dropout this is likely
            # a bug — consider model.eval(). Kept as-is (the sibling script
            # does the same, so it may be deliberate).
            model.train()
            error_lin = []
            error_log = []
            variance = []
            for i_val, (images_val, labels_val, segments) in tqdm(enumerate(valloader)):
                # FIX: was print(r'\n'), which printed a literal backslash-n
                # instead of a blank separator line.
                print('\n')
                images_val = Variable(images_val.cuda(), requires_grad=False)
                labels_val = Variable(labels_val.cuda(), requires_grad=False)
                segments = Variable(segments.cuda(), requires_grad=False)
                with torch.no_grad():
                    outputs = model(images_val)
                    pred = outputs.data.cpu().numpy()
                    gt = labels_val.data.cpu().numpy()
                    instance = segments.data.cpu().numpy()
                    pred = np.reshape(pred, (gt.shape))
                    instance = np.reshape(instance, (gt.shape))
                    # Accumulate per-region statistics — mean absolute error,
                    # log-space error and intra-region prediction variance —
                    # averaged over all instance regions (region ids start at 1).
                    var = 0
                    linear = 0
                    log_dis = 0
                    for i in range(1, int(np.max(instance) + 1)):
                        pre_region = np.where(instance == i, pred, 0)
                        dis = np.where(instance == i, np.abs(gt - pred), 0)
                        num = np.sum(np.where(instance == i, 1, 0))
                        m = np.sum(pre_region) / num
                        pre_region = np.where(instance == i, pred - m, 0)
                        pre_region = np.sum(np.square(pre_region)) / num
                        # 1e-6 keeps log() finite for zero-valued depths.
                        log_region = np.where(instance == i, np.abs(np.log(gt + 1e-6) - np.log(pred + 1e-6)), 0)
                        var += pre_region
                        linear += np.sum(dis) / num
                        log_dis += np.sum(log_region) / num
                    error_log.append(log_dis / np.max(instance))
                    error_lin.append(linear / np.max(instance))
                    variance.append(var / np.max(instance))
                    print("error_lin=%.4f,error_log=%.4f,variance=%.4f" % (
                        error_lin[i_val],
                        error_log[i_val],
                        variance[i_val]))
            error = np.mean(error_lin)
            variance = np.mean(variance)
            print("error=%.4f,variance=%.4f" % (error, variance))

            # Save a new best checkpoint whenever validation improves.
            if error <= best_error:
                best_error = error
                state = {'epoch': epoch + 1,
                         'model_state': model.state_dict(),
                         'optimizer_state': optimizer.state_dict(),
                         'error': error, }
                torch.save(state, "{}_{}_best_model.pkl".format(
                    args.arch, args.dataset))
                print('save success')
            np.save('/home/lidong/Documents/RSDEN/RSDEN//loss.npy', loss_rec)
        # Unconditional periodic checkpoint every 5 epochs (uses the most
        # recent validation error, or the initial best_error before any
        # validation has run).
        if epoch % 5 == 0:
            state = {'epoch': epoch + 1,
                     'model_state': model.state_dict(),
                     'optimizer_state': optimizer.state_dict(),
                     'error': error, }
            torch.save(state, "{}_{}_{}_model.pkl".format(
                args.arch, args.dataset, str(epoch)))
            print('save success')
Example #2
0
def train(args):
    """Train a coarse depth network (RSCFN variant) and periodically validate.

    Builds the train/val dataloaders for the 'region' task at half
    resolution (``scale = 2``), optionally opens visdom windows, constructs
    the model/optimizer/loss, resumes from a checkpoint when ``args.resume``
    points at one (otherwise warm-starts from a saved RSN checkpoint), then
    runs the epoch loop: train, validate every ``check`` epochs, and save
    best and periodic checkpoints.

    Args:
        args: argparse namespace providing at least ``dataset``, ``arch``,
            ``img_rows``, ``img_cols``, ``batch_size``, ``l_rate``,
            ``n_epoch``, ``resume`` and ``visdom``.
    """
    scale = 2   # display images are 480//scale x 640//scale
    torch.backends.cudnn.benchmark = True
    # Setup Augmentations (constructed but not handed to the loaders below).
    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])
    loss_rec = []     # [iteration index, loss tensor] pairs, persisted to loss.npy
    best_error = 2    # lowest validation error seen so far
    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path,
                           is_transform=True,
                           split='train',
                           img_size=(args.img_rows, args.img_cols),
                           task='region')
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='test',
                           img_size=(args.img_rows, args.img_cols),
                           task='region')

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=4,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=4,
                                shuffle=False)

    # Setup Metrics (instantiated for parity with sibling scripts; not
    # updated anywhere in this loop).
    running_metrics = runningScore(n_classes)

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(env='nyu2_coarse')

        # Image windows: coarse depth, refined ("accurate") depth error,
        # ground truth and the input image.
        depth_window = vis.image(
            np.random.rand(480 // scale, 640 // scale),
            opts=dict(title='depth!', caption='depth.'),
        )
        accurate_window = vis.image(
            np.random.rand(480 // scale, 640 // scale),
            opts=dict(title='accurate!', caption='accurate.'),
        )
        ground_window = vis.image(
            np.random.rand(480 // scale, 640 // scale),
            opts=dict(title='ground!', caption='ground.'),
        )
        image_window = vis.image(
            np.random.rand(480 // scale, 640 // scale),
            opts=dict(title='img!', caption='img.'),
        )
        # Line plots: training loss, training RMSE, validation error.
        loss_window = vis.line(X=torch.zeros((1, )).cpu(),
                               Y=torch.zeros((1)).cpu(),
                               opts=dict(xlabel='minibatches',
                                         ylabel='Loss',
                                         title='Training Loss',
                                         legend=['Loss']))
        lin_window = vis.line(X=torch.zeros((1, )).cpu(),
                              Y=torch.zeros((1)).cpu(),
                              opts=dict(xlabel='minibatches',
                                        ylabel='error',
                                        title='linear Loss',
                                        legend=['linear error']))
        error_window = vis.line(X=torch.zeros((1, )).cpu(),
                                Y=torch.zeros((1)).cpu(),
                                opts=dict(xlabel='minibatches',
                                          ylabel='error',
                                          title='error',
                                          legend=['Error']))
    # Setup Model, replicated across all visible GPUs.
    model = get_model(args.arch)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()

    # Prefer a model-supplied optimizer / loss when the module defines one.
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.l_rate,
                                    momentum=0.90)
    if hasattr(model.module, 'loss'):
        print('Using custom loss')
        loss_fn = model.module.loss
    else:
        loss_fn = log_loss
    trained = 0   # number of epochs already completed (from a checkpoint)
    # FIX: `test` (validation-pass counter used for the visdom error plot)
    # was undefined when args.resume was given but the file did not exist,
    # causing a NameError in the first validation pass.
    test = 0

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            trained = checkpoint['epoch']
            best_error = checkpoint['error']
            print(best_error)
            print(trained)
            # Reload the persisted loss history (199 iterations per epoch).
            loss_rec = np.load('/home/lidong/Documents/RSCFN/loss.npy')
            loss_rec = list(loss_rec)
            loss_rec = loss_rec[:199 * trained]
            # Deliberately restart the epoch counter at 0 despite resuming.
            trained = 0

    else:
        best_error = 100
        trained = 0
        print('random initialize')

        print("No checkpoint found at '{}'".format(args.resume))
        print('Initialize from rsn!')
        # Warm-start from a saved RSN checkpoint, keeping only parameters
        # that exist in this model under the same name AND shape.
        rsn = torch.load(
            '/home/lidong/Documents/RSCFN/rsn_cluster_nyu2_124_1.103912coarse_best_model.pkl',
            map_location='cpu')
        model_dict = model.state_dict()
        # FIX: the original filter also tested `rsn['model_state'].items()`,
        # which is always truthy for a non-empty state dict — redundant.
        pre_dict = {
            k: v
            for k, v in rsn['model_state'].items() if k in model_dict
        }
        # Drop entries whose tensor shapes do not match this model.
        key = []
        for k, v in pre_dict.items():
            if v.shape != model_dict[k].shape:
                key.append(k)
        for k in key:
            pre_dict.pop(k)
        model_dict.update(pre_dict)
        model.load_state_dict(model_dict)
        trained = rsn['epoch']
        best_error = rsn['error']
        print('load success!')
        print(best_error)
        # Loosen the bar slightly so the first validation can re-save.
        best_error += 1

    # FIX: define `error` before the epoch loop. Previously it was first
    # assigned only inside the validation branch, so the unconditional
    # `epoch % 10 == 0` checkpoint below raised NameError on any epoch
    # where no validation pass had run yet.
    error = best_error

    for epoch in range(trained, args.n_epoch):
        print('training!')
        model.train()

        for i, (images, labels, regions, segments,
                image) in enumerate(trainloader):
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())
            segments = Variable(segments.cuda())
            regions = Variable(regions.cuda())

            optimizer.zero_grad()

            # Forward pass produces a coarse depth map and a refined
            # ("accurate") depth map.
            depth, accurate = model(images, regions, 1, 'eval')
            print('depth', torch.mean(depth).item())
            print('accurate', torch.mean(accurate).item())
            print('ground', torch.mean(labels).item())
            loss_d = log_loss(depth, labels)
            loss_a = berhu_log(accurate, labels)
            # Only the coarse branch is trained here; loss_a is logged only.
            loss = loss_d
            # RMSE of the refined branch, for monitoring.
            lin = torch.sqrt(torch.mean(torch.pow(accurate - labels, 2)))
            # Dampen rare loss spikes to keep gradients sane.
            if loss.item() > 10:
                loss = loss / 10
            loss.backward()
            optimizer.step()
            if args.visdom:
                with torch.no_grad():
                    vis.line(X=torch.ones(1).cpu() * i + torch.ones(1).cpu() *
                             (epoch - trained) * 199,
                             Y=loss.item() * torch.ones(1).cpu(),
                             win=loss_window,
                             update='append')
                    vis.line(X=torch.ones(1).cpu() * i + torch.ones(1).cpu() *
                             (epoch - trained) * 199,
                             Y=lin.item() * torch.ones(1).cpu(),
                             win=lin_window,
                             update='append')
                    # Ground truth normalized by its max for display.
                    ground = labels.data.cpu().numpy().astype('float32')
                    ground = ground[0, :, :]
                    ground = (np.reshape(ground, [
                        480 // scale, 640 // scale
                    ]).astype('float32')) / (np.max(ground) + 0.001)
                    vis.image(
                        ground,
                        opts=dict(title='ground!', caption='ground.'),
                        win=ground_window,
                    )
                    # Absolute error of the (normalized) refined prediction.
                    accurate = accurate.data.cpu().numpy().astype('float32')
                    accurate = accurate[0, ...]
                    accurate = np.abs(
                        (np.reshape(accurate, [480 // scale, 640 //
                                               scale]).astype('float32')) /
                        (np.max(accurate) + 0.001) - ground)
                    vis.image(
                        accurate,
                        opts=dict(title='accurate!', caption='accurate.'),
                        win=accurate_window,
                    )

                    depth = depth.data.cpu().numpy().astype('float32')
                    depth = depth[0, :, :, :]
                    depth = (np.reshape(
                        depth, [480 // scale, 640 // scale
                                ]).astype('float32')) / (np.max(depth) + 0.001)
                    vis.image(
                        depth,
                        opts=dict(title='depth!', caption='depth.'),
                        win=depth_window,
                    )
                    image = image.data.cpu().numpy().astype('float32')
                    image = image[0, ...]
                    image = np.reshape(
                        image,
                        [3, 480 // scale, 640 // scale]).astype('float32')
                    vis.image(
                        image,
                        opts=dict(title='image!', caption='image.'),
                        win=image_window,
                    )
            # 199 iterations per epoch.
            loss_rec.append([
                i + epoch * 199,
                torch.Tensor([loss.item()]).unsqueeze(0).cpu()
            ])

            print(
                "data [%d/199/%d/%d] Loss: %.4f d: %.4f loss_d:%.4f loss_a:%.4f"
                % (i, epoch, args.n_epoch, loss.item(), lin.item(),
                   loss_d.item(), loss_a.item()))

        # Validate more often as training progresses.
        if epoch > 50:
            check = 3
        else:
            check = 5
        if epoch > 70:
            check = 2
        if epoch > 90:
            check = 1
        if epoch % check == 0:

            print('testing!')
            # NOTE(review): validation runs with the model still in train
            # mode; if the network contains BatchNorm/Dropout this is likely
            # a bug — consider model.eval(). Kept as-is (the sibling script
            # does the same, so it may be deliberate).
            model.train()
            loss_ave = []
            for i_val, (images_val, labels_val, regions, segments,
                        image) in tqdm(enumerate(valloader)):
                images_val = Variable(images_val.cuda(), requires_grad=False)
                labels_val = Variable(labels_val.cuda(), requires_grad=False)
                segments_val = Variable(segments.cuda(), requires_grad=False)
                regions_val = Variable(regions.cuda(), requires_grad=False)
                with torch.no_grad():
                    depth, accurate = model(images_val, regions_val, 1, 'eval')
                    # RMSE of the refined branch is the validation metric.
                    lin = torch.sqrt(
                        torch.mean(torch.pow(accurate - labels_val, 2)))
                    loss_ave.append(lin.data.cpu().numpy())
                    print("error=%.4f" % (lin.item()))
                if args.visdom:
                    vis.line(X=torch.ones(1).cpu() * i_val +
                             torch.ones(1).cpu() * test * 163,
                             Y=lin.item() * torch.ones(1).cpu(),
                             win=error_window,
                             update='append')
                    ground = labels_val.data.cpu().numpy().astype('float32')
                    ground = ground[0, :, :]
                    ground = (np.reshape(ground, [
                        480 // scale, 640 // scale
                    ]).astype('float32')) / (np.max(ground) + 0.001)
                    vis.image(
                        ground,
                        opts=dict(title='ground!', caption='ground.'),
                        win=ground_window,
                    )
                    # Unlike training, the error map is normalized after
                    # subtraction here.
                    accurate = accurate.data.cpu().numpy().astype('float32')
                    accurate = accurate[0, ...]
                    accurate = np.abs(
                        (np.reshape(accurate, [480 // scale, 640 //
                                               scale]).astype('float32')) -
                        ground)
                    accurate = accurate / (np.max(accurate) + 0.001)
                    vis.image(
                        accurate,
                        opts=dict(title='accurate!', caption='accurate.'),
                        win=accurate_window,
                    )

                    depth = depth.data.cpu().numpy().astype('float32')
                    depth = depth[0, :, :, :]
                    depth = (np.reshape(
                        depth, [480 // scale, 640 // scale
                                ]).astype('float32')) / (np.max(depth) + 0.001)
                    vis.image(
                        depth,
                        opts=dict(title='depth!', caption='depth.'),
                        win=depth_window,
                    )
                    image = image.data.cpu().numpy().astype('float32')
                    image = image[0, ...]
                    image = np.reshape(
                        image,
                        [3, 480 // scale, 640 // scale]).astype('float32')
                    vis.image(
                        image,
                        opts=dict(title='image!', caption='image.'),
                        win=image_window,
                    )
            error = np.mean(loss_ave)
            print("error_r=%.4f" % (error))
            test += 1

            # Save a new best checkpoint whenever validation improves.
            if error <= best_error:
                best_error = error
                state = {
                    'epoch': epoch + 1,
                    'model_state': model.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'error': error,
                }
                torch.save(
                    state, "{}_{}_{}_{}coarse_best_model.pkl".format(
                        args.arch, args.dataset, str(epoch), str(error)))
                print('save success')
            np.save('/home/lidong/Documents/RSCFN/loss.npy', loss_rec)

        # Unconditional periodic checkpoint every 10 epochs (uses the most
        # recent validation error, or the initial best_error before any
        # validation has run).
        if epoch % 10 == 0:
            state = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'error': error,
            }
            torch.save(
                state,
                "{}_{}_{}_coarse_model.pkl".format(args.arch, args.dataset,
                                                   str(epoch)))
            print('save success')
Example #3
0
def train(args):
    """Jointly train the cascaded 'rsnet' + 'drnet' models on region data.

    Builds train/test region loaders, wraps rsnet in DataParallel on GPUs
    0/1 and drnet on GPUs 2/3, and drives both parameter sets with a single
    SGD optimizer.  Either resumes from ``args.resume`` or initializes each
    sub-network from separately saved checkpoints (hard-coded paths).

    NOTE(review): this function is truncated in this file -- the final
    ``for epoch in range(trained, args.n_epoch):`` loop has no body, so the
    block does not parse as-is.  The epoch body must be restored from the
    original source before this code can run.

    Parameters
    ----------
    args : namespace with at least: dataset, img_rows, img_cols, batch_size,
        l_rate, n_epoch, resume, visdom.
    """

    # Setup Augmentations
    data_aug = Compose([RandomRotate(10),
                        RandomHorizontallyFlip()])
    loss_rec=[]    # history of (iteration, loss) pairs, re-plotted on resume
    best_error=2   # sentinel "best validation error" before any evaluation
    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path, is_transform=True,
                           split='train_region', img_size=(args.img_rows, args.img_cols))
    v_loader = data_loader(data_path, is_transform=True,
                           split='test_region', img_size=(args.img_rows, args.img_cols))

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader, batch_size=args.batch_size, num_workers=2, shuffle=True)
    valloader = data.DataLoader(
        v_loader, batch_size=args.batch_size, num_workers=2)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom()
        # old_window = vis.line(X=torch.zeros((1,)).cpu(),
        #                        Y=torch.zeros((1)).cpu(),
        #                        opts=dict(xlabel='minibatches',
        #                                  ylabel='Loss',
        #                                  title='Trained Loss',
        #                                  legend=['Loss']))
        # NOTE(review): 'old_window' is commented out above, but the resume
        # branch below still passes win=old_window to vis.line -> NameError
        # when resuming with visdom enabled.  Either restore this window or
        # remove the replay loop in the resume branch.
        loss_window1 = vis.line(X=torch.zeros((1,)).cpu(),
                               Y=torch.zeros((1)).cpu(),
                               opts=dict(xlabel='minibatches',
                                         ylabel='Loss',
                                         title='Training Loss1',
                                         legend=['Loss1']))
        loss_window2 = vis.line(X=torch.zeros((1,)).cpu(),
                               Y=torch.zeros((1)).cpu(),
                               opts=dict(xlabel='minibatches',
                                         ylabel='Loss',
                                         title='Training Loss2',
                                         legend=['Loss']))
        loss_window3 = vis.line(X=torch.zeros((1,)).cpu(),
                               Y=torch.zeros((1)).cpu(),
                               opts=dict(xlabel='minibatches',
                                         ylabel='Loss',
                                         title='Training Loss3',
                                         legend=['Loss3']))
        pre_window1 = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='predict1!', caption='predict1.'),
        )
        pre_window2 = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='predict2!', caption='predict2.'),
        )
        pre_window3 = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='predict3!', caption='predict3.'),
        )

        ground_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='ground!', caption='ground.'),
        )
    # Explicit device handles; NOTE(review): cuda1 and cuda3 are created but
    # never used (DataParallel's device_ids lists cover the second GPU of
    # each pair).
    cuda0=torch.device('cuda:0')
    cuda1=torch.device('cuda:1')
    cuda2=torch.device('cuda:2')
    cuda3=torch.device('cuda:3')
    # Setup Model: rsnet replicated on GPUs 0/1, drnet on GPUs 2/3.
    rsnet = get_model('rsnet')
    rsnet = torch.nn.DataParallel(rsnet, device_ids=[0,1])
    rsnet.cuda(cuda0)
    drnet=get_model('drnet')
    drnet = torch.nn.DataParallel(drnet, device_ids=[2,3])
    drnet.cuda(cuda2)
    # One optimizer over the union of both networks' parameters.
    parameters=list(rsnet.parameters())+list(drnet.parameters())
    # Check if model has custom optimizer / loss
    # modify to adam, modify the learning rate
    if hasattr(drnet.module, 'optimizer'):
        optimizer = drnet.module.optimizer
    else:
        # optimizer = torch.optim.Adam(
        #     model.parameters(), lr=args.l_rate,weight_decay=5e-4,betas=(0.9,0.999))
        optimizer = torch.optim.SGD(
            parameters, lr=args.l_rate,momentum=0.99, weight_decay=5e-4)
    if hasattr(rsnet.module, 'loss'):
        print('Using custom loss')
        loss_fn = rsnet.module.loss
    else:
        loss_fn = l1_r
    trained=0
    scale=100

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            #model_dict=model.state_dict()  
            #opt=torch.load('/home/lidong/Documents/RSDEN/RSDEN/exp1/l2/sgd/log/83/rsnet_nyu_best_model.pkl')
            # NOTE(review): 'model' is undefined in this function (the
            # networks are 'rsnet' and 'drnet'); this line raises NameError
            # whenever a resume checkpoint is supplied.
            model.load_state_dict(checkpoint['model_state'])
            #optimizer.load_state_dict(checkpoint['optimizer_state'])
            #opt=None
            print("Loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            trained=checkpoint['epoch']
            best_error=checkpoint['error']
            
            #print('load success!')
            # Replay the saved loss history into visdom, averaging per epoch
            # (3265 iterations per epoch, hard-coded).
            loss_rec=np.load('/home/lidong/Documents/RSDEN/RSDEN/loss.npy')
            loss_rec=list(loss_rec)
            loss_rec=loss_rec[:3265*trained]
            # for i in range(300):
            #     loss_rec[i][1]=loss_rec[i+300][1]
            for l in range(int(len(loss_rec)/3265)):
                if args.visdom:
                    
                    vis.line(
                        X=torch.ones(1).cpu() * loss_rec[l*3265][0],
                        Y=np.mean(np.array(loss_rec[l*3265:(l+1)*3265])[:,1])*torch.ones(1).cpu(),
                        win=old_window,
                        update='append')
            
    else:

        # NOTE(review): this message is misleading -- it also prints when
        # args.resume is simply None (no checkpoint requested).
        print("No checkpoint found at '{}'".format(args.resume))
        print('Initialize seperately!')
        # Warm-start each sub-network from its own best checkpoint
        # (hard-coded absolute paths).
        checkpoint=torch.load('/home/lidong/Documents/RSDEN/RSDEN/rsnet_nyu_best_model.pkl')
        rsnet.load_state_dict(checkpoint['model_state'])
        trained=checkpoint['epoch']
        print('load success from rsnet %.d'%trained)
        checkpoint=torch.load('/home/lidong/Documents/RSDEN/RSDEN/drnet_nyu_best_model.pkl')
        drnet.load_state_dict(checkpoint['model_state'])
        #optimizer.load_state_dict(checkpoint['optimizer_state'])
        trained=checkpoint['epoch']
        print('load success from drnet %.d'%trained)
        trained=0
        best_error=checkpoint['error']    




    # it should be range(checkpoint[''epoch],args.n_epoch)
    # NOTE(review): loop body is missing here (truncated source).
    for epoch in range(trained, args.n_epoch):
Example #4
0
def train(args):
    """Train a single depth-estimation model on region-annotated data.

    Full train/validate/checkpoint loop:
      * builds train/test loaders (``task='region'`` splits),
      * wraps the model in DataParallel across all visible GPUs,
      * trains with SGD + berhu loss, logging to visdom each minibatch,
      * validates on a thinning schedule (every 5/3/2/1 epochs as training
        progresses) using log-RMSE and linear RMSE,
      * saves the best model (by linear RMSE) and a snapshot every 15 epochs.

    Parameters
    ----------
    args : namespace with at least: arch, dataset, img_rows, img_cols,
        batch_size, l_rate, n_epoch, resume, visdom.

    Side effects: writes ``depth_{arch}_{dataset}*_model.pkl`` checkpoints
    to the working directory, appends loss history to a hard-coded
    ``loss.npy`` path, and pushes plots/images to a local visdom server.

    NOTE(review): iteration count per epoch (179) and image size (480x640)
    are hard-coded throughout -- they must match the actual dataset.
    """

    # Setup Augmentations
    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])
    loss_rec = []    # history of (iteration, loss) pairs, re-plotted on resume
    best_error = 2   # sentinel "best validation error" before any evaluation
    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path,
                           is_transform=True,
                           split='train',
                           img_size=(args.img_rows, args.img_cols),
                           task='region')
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='test',
                           img_size=(args.img_rows, args.img_cols),
                           task='region')

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=4,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=4)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup visdom for visualization: four image panes + two loss curves.
    if args.visdom:
        vis = visdom.Visdom()

        depth_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='depth!', caption='depth.'),
        )
        cluster_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='cluster!', caption='cluster.'),
        )
        region_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='region!', caption='region.'),
        )
        ground_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='ground!', caption='ground.'),
        )
        loss_window = vis.line(X=torch.zeros((1, )).cpu(),
                               Y=torch.zeros((1)).cpu(),
                               opts=dict(xlabel='minibatches',
                                         ylabel='Loss',
                                         title='Training Loss',
                                         legend=['Loss']))
        old_window = vis.line(X=torch.zeros((1, )).cpu(),
                              Y=torch.zeros((1)).cpu(),
                              opts=dict(xlabel='minibatches',
                                        ylabel='Loss',
                                        title='Trained Loss',
                                        legend=['Loss']))
    # Setup Model
    model = get_model(args.arch)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    #model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    model.cuda()

    # Check if model has custom optimizer / loss
    # modify to adam, modify the learning rate
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        # optimizer = torch.optim.Adam(
        #     model.parameters(), lr=args.l_rate,weight_decay=5e-4,betas=(0.9,0.999))
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.l_rate,
                                    momentum=0.90,
                                    weight_decay=5e-4)
    if hasattr(model.module, 'loss'):
        print('Using custom loss')
        loss_fn = model.module.loss
    else:
        loss_fn = log_loss
    trained = 0
    scale = 100

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            # Loaded on CPU first; model.cuda() above moves params back to GPU.
            checkpoint = torch.load(args.resume, map_location='cpu')
            #model_dict=model.state_dict()
            #opt=torch.load('/home/lidong/Documents/RSDEN/RSDEN/exp1/l2/sgd/log/83/rsnet_nyu_best_model.pkl')
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            #opt=None
            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            trained = checkpoint['epoch']
            best_error = checkpoint['error']
            #best_error_d=checkpoint['error_d']
            best_error_d = checkpoint['error_d']
            print(best_error)
            print(trained)
            # Replay the saved loss history into visdom, averaged per epoch
            # (179 iterations per epoch, hard-coded).
            loss_rec = np.load('/home/lidong/Documents/RSDEN/RSDEN/loss.npy')
            loss_rec = list(loss_rec)
            loss_rec = loss_rec[:179 * trained]
            # for i in range(300):
            #     loss_rec[i][1]=loss_rec[i+300][1]
            for l in range(int(len(loss_rec) / 179)):
                if args.visdom:

                    vis.line(
                        X=torch.ones(1).cpu() * loss_rec[l * 179][0],
                        Y=np.mean(
                            np.array(loss_rec[l * 179:(l + 1) * 179])[:, 1]) *
                        torch.ones(1).cpu(),
                        win=old_window,
                        update='append')
            #exit()

    else:
        # Fresh run: no checkpoint, start from random weights.
        best_error = 100
        best_error_d = 100
        trained = 0
        print('random initialize')
        """
        print("No checkpoint found at '{}'".format(args.resume))
        print('Initialize from rsn!')
        rsn=torch.load('/home/lidong/Documents/RSDEN/RSDEN/depth_rsn_cluster_nyu2_best_model.pkl',map_location='cpu')
        model_dict=model.state_dict()  
        #print(model_dict)          
        #pre_dict={k: v for k, v in rsn['model_state'].items() if k in model_dict and rsn['model_state'].items()}
        pre_dict={k: v for k, v in rsn.items() if k in model_dict and rsn.items()}
        key=[]
        for k,v in pre_dict.items():
            if v.shape!=model_dict[k].shape:
                key.append(k)
        for k in key:
            pre_dict.pop(k)
        model_dict.update(pre_dict)
        model.load_state_dict(model_dict)
        #trained=rsn['epoch']
        #best_error=rsn['error']
        #best_error_d=checkpoint['error_d']
        #best_error_d=rsn['error_d']
        print('load success!')
        print(best_error)
        print(trained)
        print(best_error_d)
        del rsn
        """

    # it should be range(checkpoint[''epoch],args.n_epoch)
    for epoch in range(trained, args.n_epoch):
        #for epoch in range(0, args.n_epoch):

        #trained
        print('training!')
        model.train()
        for i, (images, labels, regions, segments) in enumerate(trainloader):
            #break
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())
            segments = Variable(segments.cuda())
            regions = Variable(regions.cuda())

            optimizer.zero_grad()

            # depth,feature,loss_var,loss_dis,loss_reg = model(images,segments)
            # loss_d=l2(depth,labels)
            # loss=torch.sum(loss_var)+torch.sum(loss_dis)+0.001*torch.sum(loss_reg)
            # loss=loss/4+loss_d
            # loss/=2
            depth = model(images, segments)
            loss_d = berhu(depth, labels)   # optimized objective
            lin = l2(depth, labels)         # linear error, logging only
            loss = loss_d
            loss.backward()
            optimizer.step()
            # NOTE(review): dead/broken branch -- 'feature' is never
            # assigned in this loop (the multi-output model call above is
            # commented out), so if this near-impossible condition ever
            # fired it would raise NameError.  Also uses vis/cluster_window
            # outside the args.visdom guard.
            if loss.item() <= 0.000001:
                feature = feature.data.cpu().numpy().astype('float32')[0, ...]
                feature = np.reshape(
                    feature,
                    [1, feature.shape[0], feature.shape[1], feature.shape[2]])
                feature = np.transpose(feature, [0, 2, 3, 1])
                print(feature.shape)
                #feature = feature[0,...]
                masks = get_instance_masks(feature, 0.7)
                print(masks.shape)
                #cluster = masks[0]
                cluster = np.sum(masks, axis=0)
                # Min-max normalize to [0,1); +1 guards division by zero.
                cluster = (np.reshape(cluster, [480, 640]).astype('float32') -
                           np.min(cluster)) / (np.max(cluster) -
                                               np.min(cluster) + 1)

                vis.image(
                    cluster,
                    opts=dict(title='cluster!', caption='cluster.'),
                    win=cluster_window,
                )
            if args.visdom:
                # Append this minibatch's loss at a global x-position
                # (179 iterations per epoch, hard-coded).
                vis.line(X=torch.ones(1).cpu() * i + torch.ones(1).cpu() *
                         (epoch - trained) * 179,
                         Y=loss.item() * torch.ones(1).cpu(),
                         win=loss_window,
                         update='append')
                # Show the first sample of the batch, min-max normalized
                # (assumes 480x640 output resolution).
                depth = depth.data.cpu().numpy().astype('float32')
                depth = depth[0, :, :, :]
                depth = (np.reshape(depth, [480, 640]).astype('float32') -
                         np.min(depth)) / (np.max(depth) - np.min(depth) + 1)
                vis.image(
                    depth,
                    opts=dict(title='depth!', caption='depth.'),
                    win=depth_window,
                )

                region = regions.data.cpu().numpy().astype('float32')
                region = region[0, ...]
                region = (np.reshape(region, [480, 640]).astype('float32') -
                          np.min(region)) / (np.max(region) - np.min(region) +
                                             1)
                vis.image(
                    region,
                    opts=dict(title='region!', caption='region.'),
                    win=region_window,
                )
                ground = labels.data.cpu().numpy().astype('float32')
                ground = ground[0, :, :]
                ground = (np.reshape(ground, [480, 640]).astype('float32') -
                          np.min(ground)) / (np.max(ground) - np.min(ground) +
                                             1)
                vis.image(
                    ground,
                    opts=dict(title='ground!', caption='ground.'),
                    win=ground_window,
                )
            loss_rec.append([
                i + epoch * 179,
                torch.Tensor([loss.item()]).unsqueeze(0).cpu()
            ])

            # print("data [%d/179/%d/%d] Loss: %.4f loss_var: %.4f loss_dis: %.4f loss_reg: %.4f loss_d: %.4f" % (i, epoch, args.n_epoch,loss.item(), \
            #                     torch.sum(loss_var).item()/4,torch.sum(loss_dis).item()/4,0.001*torch.sum(loss_reg).item()/4,loss_d.item()))
            print("data [%d/179/%d/%d] Loss: %.4f linear: %.4f " %
                  (i, epoch, args.n_epoch, loss.item(), lin.item()))

        # Validate more often as training progresses: every 5 epochs early,
        # then every 3 (>30), 2 (>50), and every epoch (>70).
        if epoch > 30:
            check = 3
        else:
            check = 5
        if epoch > 50:
            check = 2
        if epoch > 70:
            check = 1
        #epoch=3
        if epoch % check == 0:

            print('testing!')
            model.eval()
            loss_ave = []
            loss_d_ave = []
            loss_lin_ave = []
            for i_val, (images_val, labels_val, regions,
                        segments) in tqdm(enumerate(valloader)):
                #print(r'\n')
                images_val = Variable(images_val.cuda(), requires_grad=False)
                labels_val = Variable(labels_val.cuda(), requires_grad=False)
                # NOTE(review): segments_val is built but the forward pass
                # below uses it; regions_val is built and never used.
                segments_val = Variable(segments.cuda(), requires_grad=False)
                regions_val = Variable(regions.cuda(), requires_grad=False)
                with torch.no_grad():

                    #depth,feature,loss_var,loss_dis,loss_reg = model(images_val,segments_val)
                    depth = model(images_val, segments_val)
                    # loss=torch.sum(loss_var)+torch.sum(loss_dis)+0.001*torch.sum(loss_reg)
                    # loss=loss/4
                    loss_d = log_loss(input=depth, target=labels_val)
                    loss_d = torch.sqrt(loss_d)     # log-RMSE
                    loss_lin = l2(depth, labels_val)
                    loss_lin = torch.sqrt(loss_lin) # linear RMSE
                    # loss_r=(loss+loss_d)/2
                    # loss_ave.append(loss_r.data.cpu().numpy())
                    loss_d_ave.append(loss_d.data.cpu().numpy())
                    loss_lin_ave.append(loss_lin.data.cpu().numpy())
                    print('error:')
                    print(loss_d_ave[-1])
                    # print(loss_ave[-1])
                    print(loss_lin_ave[-1])
                    #exit()

                    # feature = feature.data.cpu().numpy().astype('float32')[0,...]
                    # feature=np.reshape(feature,[1,feature.shape[0],feature.shape[1],feature.shape[2]])
                    # feature=np.transpose(feature,[0,2,3,1])
                    # #print(feature.shape)
                    # #feature = feature[0,...]
                    # masks=get_instance_masks(feature, 0.7)
                    # #print(len(masks))
                    # cluster = np.array(masks)
                    # cluster=np.sum(masks,axis=0)
                    # cluster = np.reshape(cluster, [480, 640]).astype('float32')/255

                    # vis.image(
                    #     cluster,
                    #     opts=dict(title='cluster!', caption='cluster.'),
                    #     win=cluster_window,
                    # )
                    # ground=segments.data.cpu().numpy().astype('float32')
                    # ground = ground[0, :, :]
                    # ground = (np.reshape(ground, [480, 640]).astype('float32')-np.min(ground))/(np.max(ground)-np.min(ground)+1)
                    # vis.image(
                    #     ground,
                    #     opts=dict(title='ground!', caption='ground.'),
                    #     win=ground_window,
                    # )
            #error=np.mean(loss_ave)
            error_d = np.mean(loss_d_ave)
            error_lin = np.mean(loss_lin_ave)
            #error_rate=np.mean(error_rate)
            print("error_d=%.4f error_lin=%.4f" % (error_d, error_lin))
            #exit()
            #continue
            # if error_d<= best_error:
            #     best_error = error
            #     state = {'epoch': epoch+1,
            #              'model_state': model.state_dict(),
            #              'optimizer_state': optimizer.state_dict(),
            #              'error': error,
            #              'error_d': error_d,
            #              }
            #     torch.save(state, "{}_{}_best_model.pkl".format(
            #         args.arch, args.dataset))
            #     print('save success')
            # np.save('/home/lidong/Documents/RSDEN/RSDEN/loss.npy',loss_rec)
            # Model selection by linear RMSE; overwrite the best checkpoint.
            if error_lin <= best_error:
                best_error = error_lin
                state = {
                    'epoch': epoch + 1,
                    'model_state': model.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'error': error_lin,
                    'error_d': error_d,
                }
                torch.save(
                    state, "depth_{}_{}_best_model.pkl".format(
                        args.arch, args.dataset))
                print('save success')
            np.save('/home/lidong/Documents/RSDEN/RSDEN/loss.npy', loss_rec)
        # Periodic snapshot every 15 epochs, regardless of quality.
        # NOTE(review): relies on error_lin/error_d having been defined by
        # the test branch above; with the current 'check' schedule every
        # multiple of 15 does trigger a test, but this coupling is fragile.
        if epoch % 15 == 0:
            #best_error = error
            state = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'error': error_lin,
                'error_d': error_d,
            }
            torch.save(
                state,
                "depth_{}_{}_{}_model.pkl".format(args.arch, args.dataset,
                                                  str(epoch)))
            print('save success')
Example #5
0
def train(args):

    # Setup Augmentations
    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])
    loss_rec = []
    best_error = 2
    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path,
                           is_transform=True,
                           split='train_region',
                           img_size=(args.img_rows, args.img_cols))
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='test_region',
                           img_size=(args.img_rows, args.img_cols))

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=4,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=4)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom()
        old_window = vis.line(X=torch.zeros((1, )).cpu(),
                              Y=torch.zeros((1)).cpu(),
                              opts=dict(xlabel='minibatches',
                                        ylabel='Loss',
                                        title='Trained Loss',
                                        legend=['Loss']))
        loss_window1 = vis.line(X=torch.zeros((1, )).cpu(),
                                Y=torch.zeros((1)).cpu(),
                                opts=dict(xlabel='minibatches',
                                          ylabel='Loss',
                                          title='Training Loss1',
                                          legend=['Loss1']))
        loss_window2 = vis.line(X=torch.zeros((1, )).cpu(),
                                Y=torch.zeros((1)).cpu(),
                                opts=dict(xlabel='minibatches',
                                          ylabel='Loss',
                                          title='Training Loss2',
                                          legend=['Loss']))
        loss_window3 = vis.line(X=torch.zeros((1, )).cpu(),
                                Y=torch.zeros((1)).cpu(),
                                opts=dict(xlabel='minibatches',
                                          ylabel='Loss',
                                          title='Training Loss3',
                                          legend=['Loss3']))
        pre_window1 = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='predict1!', caption='predict1.'),
        )
        pre_window2 = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='predict2!', caption='predict2.'),
        )
        pre_window3 = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='predict3!', caption='predict3.'),
        )
        support_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='support!', caption='support.'),
        )
        ground_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='ground!', caption='ground.'),
        )
    # Setup Model
    model = get_model(args.arch)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    #model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    model.cuda()

    # Check if model has custom optimizer / loss
    # modify to adam, modify the learning rate
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        # optimizer = torch.optim.Adam(
        #     model.parameters(), lr=args.l_rate,weight_decay=5e-4,betas=(0.9,0.999))
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.l_rate,
                                    momentum=0.99,
                                    weight_decay=5e-4)
    if hasattr(model.module, 'loss'):
        print('Using custom loss')
        loss_fn = model.module.loss
    else:
        loss_fn = log_r
    trained = 0
    scale = 100

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)
            model_dict = model.state_dict()
            pre_dict = {
                k: v
                for k, v in checkpoint['model_state'].items()
                if k in model_dict
            }

            model_dict.update(pre_dict)
            #print(model_dict['module.conv1.weight'].shape)
            model_dict['module.conv1.weight'] = torch.cat([
                model_dict['module.conv1.weight'],
                torch.reshape(model_dict['module.conv1.weight'][:, 3, :, :],
                              [64, 1, 7, 7])
            ], 1)
            #print(model_dict['module.conv1.weight'].shape)
            model.load_state_dict(model_dict)
            #model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            trained = checkpoint['epoch']
            print('load success!')
            #optimizer.load_state_dict(checkpoint['optimizer_state'])
            #opt=None
            opti_dict = optimizer.state_dict()
            #pre_dict={k: v for k, v in checkpoint['optimizer_state'].items() if k in opti_dict}
            pre_dict = checkpoint['optimizer_state']
            # for k,v in pre_dict.items():
            #     print(k)
            #     if k=='state':
            #         #print(v.type)
            #         for a,b in v.items():
            #             print(a)
            #             print(b['momentum_buffer'].shape)
            # return 0
            opti_dict.update(pre_dict)
            # for k,v in opti_dict.items():
            #     print(k)
            #     if k=='state':
            #         #print(v.type)
            #         for a,b in v.items():
            #             if a==140011149405280:
            #                 print(b['momentum_buffer'].shape)
            #print(opti_dict['state'][140011149405280]['momentum_buffer'].shape)
            opti_dict['state'][139629660382048]['momentum_buffer'] = torch.cat(
                [
                    opti_dict['state'][139629660382048]['momentum_buffer'],
                    torch.reshape(
                        opti_dict['state'][139629660382048]['momentum_buffer']
                        [:, 3, :, :], [64, 1, 7, 7])
                ], 1)
            #print(opti_dict['module.conv1.weight'].shape)
            optimizer.load_state_dict(opti_dict)
            best_error = checkpoint['error'] + 0.15

            # #print('load success!')
            # loss_rec=np.load('/home/lidong/Documents/RSDEN/RSDEN/loss.npy')
            # loss_rec=list(loss_rec)
            # loss_rec=loss_rec[:816*trained]
            # # for i in range(300):
            # #     loss_rec[i][1]=loss_rec[i+300][1]
            # for l in range(int(len(loss_rec)/816)):
            #     if args.visdom:
            #         #print(np.array(loss_rec[l])[1:])
            #         # vis.line(
            #         #     X=torch.ones(1).cpu() * loss_rec[l][0],
            #         #     Y=np.mean(np.array(loss_rec[l])[1:])*torch.ones(1).cpu(),
            #         #     win=old_window,
            #         #     update='append')
            #         vis.line(
            #             X=torch.ones(1).cpu() * loss_rec[l*816][0],
            #             Y=np.mean(np.array(loss_rec[l*816:(l+1)*816])[:,1])*torch.ones(1).cpu(),
            #             win=old_window,
            #             update='append')

    else:

        print("No checkpoint found at '{}'".format(args.resume))
        print('Initialize from resnet34!')
        #resnet34=torch.load('/home/lidong/Documents/RSDEN/RSDEN/resnet34-333f7ec4.pth')
        resnet34 = torch.load(
            '/home/lidong/Documents/RSDEN/RSDEN/rsnet_nyu_best_model.pkl')
        model_dict = model.state_dict()
        # for k,v in resnet34['model_state'].items():
        #     print(k)
        pre_dict = {
            k: v
            for k, v in resnet34['model_state'].items() if k in model_dict
        }
        # for k,v in pre_dict.items():e
        #     print(k)

        model_dict.update(pre_dict)
        model_dict['module.conv1.weight'] = torch.cat([
            model_dict['module.conv1.weight'],
            torch.mean(model_dict['module.conv1.weight'], 1, keepdim=True)
        ], 1)
        # model_dict['module.conv1.weight']=torch.transpose(model_dict['module.conv1.weight'],1,2)
        # model_dict['module.conv1.weight']=torch.transpose(model_dict['module.conv1.weight'],2,4)
        model.load_state_dict(model_dict)
        print('load success!')
        best_error = 1
        trained = 0

    # it should be range(checkpoint[''epoch],args.n_epoch)
    for epoch in range(trained, args.n_epoch):
        #for epoch in range(0, args.n_epoch):

        #trained
        print('training!')
        model.train()

        for i, (images, labels, segments) in enumerate(trainloader):
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())
            segments = Variable(segments.cuda())
            # print(segments.shape)
            # print(images.shape)
            images = torch.cat([images, segments], 1)
            images = torch.cat([images, segments], 1)
            optimizer.zero_grad()
            outputs = model(images)
            #outputs=torch.reshape(outputs,[outputs.shape[0],1,outputs.shape[1],outputs.shape[2]])
            #outputs=outputs
            loss = loss_fn(input=outputs, target=labels)
            out = 0.2 * loss[0] + 0.3 * loss[1] + 0.5 * loss[2]
            # print('training:'+str(i)+':learning_rate'+str(loss.data.cpu().numpy()))
            out.backward()
            optimizer.step()
            # print(torch.Tensor([loss.data[0]]).unsqueeze(0).cpu())
            #print(loss.item()*torch.ones(1).cpu())
            #nyu2_train:246,nyu2_all:816
            if args.visdom:
                vis.line(X=torch.ones(1).cpu() * i + torch.ones(1).cpu() *
                         (epoch - trained) * 816,
                         Y=loss[0].item() * torch.ones(1).cpu(),
                         win=loss_window1,
                         update='append')
                vis.line(X=torch.ones(1).cpu() * i + torch.ones(1).cpu() *
                         (epoch - trained) * 816,
                         Y=loss[1].item() * torch.ones(1).cpu(),
                         win=loss_window2,
                         update='append')
                vis.line(X=torch.ones(1).cpu() * i + torch.ones(1).cpu() *
                         (epoch - trained) * 816,
                         Y=loss[2].item() * torch.ones(1).cpu(),
                         win=loss_window3,
                         update='append')
                pre = outputs[0].data.cpu().numpy().astype('float32')
                pre = pre[0, :, :]
                #pre = np.argmax(pre, 0)
                pre = (np.reshape(pre, [480, 640]).astype('float32') -
                       np.min(pre)) / (np.max(pre) - np.min(pre))
                #pre = pre/np.max(pre)
                # print(type(pre[0,0]))
                vis.image(
                    pre,
                    opts=dict(title='predict1!', caption='predict1.'),
                    win=pre_window1,
                )
                pre = outputs[1].data.cpu().numpy().astype('float32')
                pre = pre[0, :, :]
                #pre = np.argmax(pre, 0)
                pre = (np.reshape(pre, [480, 640]).astype('float32') -
                       np.min(pre)) / (np.max(pre) - np.min(pre))
                #pre = pre/np.max(pre)
                # print(type(pre[0,0]))
                vis.image(
                    pre,
                    opts=dict(title='predict2!', caption='predict2.'),
                    win=pre_window2,
                )
                pre = outputs[2].data.cpu().numpy().astype('float32')
                pre = pre[0, :, :]
                #pre = np.argmax(pre, 0)
                pre = (np.reshape(pre, [480, 640]).astype('float32') -
                       np.min(pre)) / (np.max(pre) - np.min(pre))
                #pre = pre/np.max(pre)
                # print(type(pre[0,0]))
                vis.image(
                    pre,
                    opts=dict(title='predict3!', caption='predict3.'),
                    win=pre_window3,
                )
                ground = labels.data.cpu().numpy().astype('float32')
                #print(ground.shape)
                ground = ground[0, :, :]
                ground = (np.reshape(ground, [480, 640]).astype('float32') -
                          np.min(ground)) / (np.max(ground) - np.min(ground))
                vis.image(
                    ground,
                    opts=dict(title='ground!', caption='ground.'),
                    win=ground_window,
                )
                ground = segments.data.cpu().numpy().astype('float32')
                #print(ground.shape)
                ground = ground[0, :, :]
                ground = (np.reshape(ground, [480, 640]).astype('float32') -
                          np.min(ground)) / (np.max(ground) - np.min(ground))
                vis.image(
                    ground,
                    opts=dict(title='support!', caption='support.'),
                    win=support_window,
                )

            loss_rec.append([
                i + epoch * 816,
                torch.Tensor([loss[0].item()]).unsqueeze(0).cpu(),
                torch.Tensor([loss[1].item()]).unsqueeze(0).cpu(),
                torch.Tensor([loss[2].item()]).unsqueeze(0).cpu()
            ])
            print("data [%d/816/%d/%d] Loss1: %.4f Loss2: %.4f Loss3: %.4f" %
                  (i, epoch, args.n_epoch, loss[0].item(), loss[1].item(),
                   loss[2].item()))

        #epoch=3
        if epoch % 1 == 0:
            print('testing!')
            model.train()
            error_lin = []
            error_log = []
            error_va = []
            error_rate = []
            error_absrd = []
            error_squrd = []
            thre1 = []
            thre2 = []
            thre3 = []

            for i_val, (images_val, labels_val,
                        segments) in tqdm(enumerate(valloader)):
                print(r'\n')
                images_val = Variable(images_val.cuda(), requires_grad=False)
                labels_val = Variable(labels_val.cuda(), requires_grad=False)
                segments = Variable(segments.cuda())
                images_val = torch.cat([images_val, segments], 1)
                images_val = torch.cat([images_val, segments], 1)
                with torch.no_grad():
                    outputs = model(images_val)
                    pred = outputs[2].data.cpu().numpy()
                    gt = labels_val.data.cpu().numpy()
                    ones = np.ones((gt.shape))
                    zeros = np.zeros((gt.shape))
                    pred = np.reshape(pred, (gt.shape))
                    #gt=np.reshape(gt,[4,480,640])
                    dis = np.square(gt - pred)
                    error_lin.append(np.sqrt(np.mean(dis)))
                    dis = np.square(np.log(gt) - np.log(pred))
                    error_log.append(np.sqrt(np.mean(dis)))
                    alpha = np.mean(np.log(gt) - np.log(pred))
                    dis = np.square(np.log(pred) - np.log(gt) + alpha)
                    error_va.append(np.mean(dis) / 2)
                    dis = np.mean(np.abs(gt - pred)) / gt
                    error_absrd.append(np.mean(dis))
                    dis = np.square(gt - pred) / gt
                    error_squrd.append(np.mean(dis))
                    thelt = np.where(pred / gt > gt / pred, pred / gt,
                                     gt / pred)
                    thres1 = 1.25

                    thre1.append(np.mean(np.where(thelt < thres1, ones,
                                                  zeros)))
                    thre2.append(
                        np.mean(np.where(thelt < thres1 * thres1, ones,
                                         zeros)))
                    thre3.append(
                        np.mean(
                            np.where(thelt < thres1 * thres1 * thres1, ones,
                                     zeros)))
                    #a=thre1[i_val]
                    #error_rate.append(np.mean(np.where(dis<0.6,ones,zeros)))
                    print(
                        "error_lin=%.4f,error_log=%.4f,error_va=%.4f,error_absrd=%.4f,error_squrd=%.4f,thre1=%.4f,thre2=%.4f,thre3=%.4f"
                        % (error_lin[i_val], error_log[i_val], error_va[i_val],
                           error_absrd[i_val], error_squrd[i_val],
                           thre1[i_val], thre2[i_val], thre3[i_val]))
            error = np.mean(error_lin)
            #error_rate=np.mean(error_rate)
            print("error=%.4f" % (error))

            if error <= best_error:
                best_error = error
                state = {
                    'epoch': epoch + 1,
                    'model_state': model.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'error': error,
                }
                torch.save(
                    state,
                    "{}_{}_best_model.pkl".format(args.arch, args.dataset))
                print('save success')
            np.save('/home/lidong/Documents/RSDEN/RSDEN//loss.npy', loss_rec)
        if epoch % 10 == 0:
            #best_error = error
            state = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'error': error,
            }
            torch.save(
                state, "{}_{}_{}_model.pkl".format(args.arch, args.dataset,
                                                   str(epoch)))
            print('save success')
# ---------------------------------------------------------------------------
# Example #6
# ---------------------------------------------------------------------------
def train(args):
    """Run the pre-trained rsnet/drnet pair over the 'visual' split and dump
    per-batch predictions and samples as .npy files.

    NOTE(review): despite its name, this variant performs no weight updates —
    it loads both checkpoints, runs one no-grad pass over ``valloader`` and
    then breaks out of the epoch loop.
    """

    # Setup Augmentations (built but not passed to the loaders below)
    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])
    loss_rec = []
    best_error = 2
    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path,
                           is_transform=True,
                           split='train_region',
                           img_size=(args.img_rows, args.img_cols),
                           task='visualize')
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='visual',
                           img_size=(args.img_rows, args.img_cols),
                           task='visualize')

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=2,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=2)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup visdom for visualization
    # NOTE(review): unlike the sibling variants, no visdom windows are created
    # here, yet the resume branch below calls vis.line with old_window — both
    # names would be undefined if that branch runs. Confirm before resuming.

    # Pin each sub-network to a fixed device: rsnet on cuda:0, drnet on cuda:2.
    cuda0 = torch.device('cuda:0')
    cuda1 = torch.device('cuda:1')
    cuda2 = torch.device('cuda:2')
    cuda3 = torch.device('cuda:3')
    # Setup Model
    rsnet = get_model('rsnet')
    rsnet = torch.nn.DataParallel(rsnet, device_ids=[0])
    rsnet.cuda(cuda0)
    drnet = get_model('drnet')
    drnet = torch.nn.DataParallel(drnet, device_ids=[2])
    drnet.cuda(cuda2)
    # Joint parameter list (unused unless the SGD branch below is taken).
    parameters = list(rsnet.parameters()) + list(drnet.parameters())
    # Check if model has custom optimizer / loss
    # modify to adam, modify the learning rate
    if hasattr(drnet.module, 'optimizer'):
        optimizer = drnet.module.optimizer
    else:
        # optimizer = torch.optim.Adam(
        #     model.parameters(), lr=args.l_rate,weight_decay=5e-4,betas=(0.9,0.999))
        optimizer = torch.optim.SGD(parameters,
                                    lr=args.l_rate,
                                    momentum=0.99,
                                    weight_decay=5e-4)
    if hasattr(rsnet.module, 'loss'):
        print('Using custom loss')
        loss_fn = rsnet.module.loss
    else:
        loss_fn = l1_r
    trained = 0
    scale = 100

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)
            #model_dict=model.state_dict()
            #opt=torch.load('/home/lidong/Documents/RSDEN/RSDEN/exp1/l2/sgd/log/83/rsnet_nyu_best_model.pkl')
            # NOTE(review): `model` is not defined in this function (only
            # rsnet/drnet exist); taking this branch raises NameError.
            # Probably copy-pasted from the single-model variant — confirm
            # which network this checkpoint belongs to.
            model.load_state_dict(checkpoint['model_state'])
            #optimizer.load_state_dict(checkpoint['optimizer_state'])
            #opt=None
            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            trained = checkpoint['epoch']
            best_error = checkpoint['error']

            #print('load success!')
            # Replay the recorded loss history (1632 minibatches per epoch).
            loss_rec = np.load('/home/lidong/Documents/RSDEN/RSDEN/loss.npy')
            loss_rec = list(loss_rec)
            loss_rec = loss_rec[:1632 * trained]
            # for i in range(300):
            #     loss_rec[i][1]=loss_rec[i+300][1]
            for l in range(int(len(loss_rec) / 1632)):
                if args.visdom:

                    # NOTE(review): `vis` and `old_window` are never created
                    # in this variant — this call would fail with args.visdom.
                    vis.line(
                        X=torch.ones(1).cpu() * loss_rec[l * 1632][0],
                        Y=np.mean(
                            np.array(loss_rec[l * 1632:(l + 1) * 1632])[:, 1])
                        * torch.ones(1).cpu(),
                        win=old_window,
                        update='append')

    else:

        # No --resume given: load rsnet and drnet from their separate
        # best-model checkpoints instead.
        print("No checkpoint found at '{}'".format(args.resume))
        print('Initialize seperately!')
        checkpoint = torch.load(
            '/home/lidong/Documents/RSDEN/RSDEN/exp1/region/trained/rsnet_nyu_best_model.pkl'
        )
        rsnet.load_state_dict(checkpoint['model_state'])
        trained = checkpoint['epoch']
        print('load success from rsnet %.d' % trained)
        best_error = checkpoint['error']
        checkpoint = torch.load(
            '//home/lidong/Documents/RSDEN/RSDEN/exp1/seg/drnet_nyu_best_model.pkl'
        )
        drnet.load_state_dict(checkpoint['model_state'])
        #optimizer.load_state_dict(checkpoint['optimizer_state'])
        trained = checkpoint['epoch']
        print('load success from drnet %.d' % trained)
        trained = 0

    min_loss = 10
    samples = []

    # it should be range(checkpoint[''epoch],args.n_epoch)
    for epoch in range(trained, args.n_epoch):

        # NOTE(review): .train() is used even though this is inference-only,
        # so batch-norm/dropout stay in training mode — confirm intended.
        rsnet.train()
        drnet.train()

        if epoch % 1 == 0:
            print('testing!')
            rsnet.train()
            drnet.train()
            error_lin = []
            error_log = []
            error_va = []
            error_rate = []
            error_absrd = []
            error_squrd = []
            thre1 = []
            thre2 = []
            thre3 = []

            for i_val, (images, labels, segments,
                        sample) in tqdm(enumerate(valloader)):
                #print(r'\n')
                images = images.cuda(cuda2)
                labels = labels.cuda(cuda2)
                segments = segments.cuda(cuda2)
                optimizer.zero_grad()
                #print(i_val)

                with torch.no_grad():
                    #region_support = rsnet(images)
                    # Ground-truth segments are concatenated as the region
                    # channel (rsnet's own prediction is commented out above).
                    coarse_depth = torch.cat([images, segments], 1)
                    #coarse_depth=torch.cat([coarse_depth,segments],1)
                    outputs = drnet(coarse_depth)
                    #print(outputs[2].item())
                    # Stack the three refinement stages into one array.
                    pred = [
                        outputs[0].data.cpu().numpy(),
                        outputs[1].data.cpu().numpy(),
                        outputs[2].data.cpu().numpy()
                    ]
                    pred = np.array(pred)
                    #print(pred.shape)
                    #pred=region_support.data.cpu().numpy()
                    gt = labels.data.cpu().numpy()
                    ones = np.ones((gt.shape))
                    zeros = np.zeros((gt.shape))
                    # Move the 3 stages to the last axis: (N, H, W, 3).
                    pred = np.reshape(
                        pred, (gt.shape[0], gt.shape[1], gt.shape[2], 3))
                    #pred=np.reshape(pred,(gt.shape))
                    print(np.max(pred))
                    #print(gt.shape)
                    #print(pred.shape)
                    #gt=np.reshape(gt,[4,480,640])
                    # RMSE of the final (third) stage against ground truth.
                    dis = np.square(gt - pred[:, :, :, 2])
                    #dis=np.square(gt-pred)
                    loss = np.sqrt(np.mean(dis))
                    #print(min_loss)
                    # NOTE(review): min_loss starts at 10 and loss >= 0, so
                    # this condition is effectively always true — every batch
                    # is saved. Probably `loss < min_loss` was intended.
                    if min_loss > 0:
                        #print(loss)
                        min_loss = loss
                        #pre=pred[:,:,0]
                        #region_support=region_support.item()
                        #rgb=rgb
                        #segments=segments
                        #labels=labels.item()
                        #sample={'loss':loss,'rgb':rgb,'region_support':region_support,'ground_r':segments,'ground_d':labels}
                        #samples.append(sample)
                        #pred=pred.item()
                        #pred=pred[0,:,:]
                        #pred=pred/np.max(pred)*255
                        #pred=pred.astype(np.uint8)
                        #print(pred.shape)
                        #cv2.imwrite('/home/lidong/Documents/RSDEN/RSDEN/exp1/pred/seg%.d.png'%(i_val),pred)
                        # Dump predictions and the raw sample for offline
                        # visualization, keyed by batch index.
                        np.save(
                            '/home/lidong/Documents/RSDEN/RSDEN/exp1/pred/seg%.d.npy'
                            % (i_val), pred)
                        np.save(
                            '/home/lidong/Documents/RSDEN/RSDEN/exp1/visual/seg%.d.npy'
                            % (i_val), sample)
            # Single pass only: stop after the first epoch's evaluation.
            break
# ---------------------------------------------------------------------------
# Example #7
# ---------------------------------------------------------------------------
def train(args):
    """Train a single depth-regression model and evaluate after every epoch.

    Reads from ``args``: dataset, img_rows, img_cols, batch_size, visdom,
    arch, resume, l_rate, n_epoch.  Saves the checkpoint with the lowest
    mean-absolute validation error to ``{arch}_{dataset}_best_model.pkl``.

    Fixes vs. the previous revision:
      * ``loss.data[0]`` -> ``loss.item()`` (indexing a 0-dim tensor raises
        in PyTorch >= 0.5; the sibling trainers in this file already use
        ``.item()``).
      * ``Variable(..., volatile=True)`` -> ``torch.no_grad()`` block, as the
        sibling validation loops do (``volatile`` has been removed).
      * Validation reshape no longer hard-codes batch size 4, so a short
        final batch does not crash.
      * The best-model checkpoint now also stores ``'error'``, matching the
        other trainers that read ``checkpoint['error']`` on resume.
    """

    # Setup Augmentations (constructed but not wired into the loaders).
    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path,
                           is_transform=True,
                           split='train',
                           img_size=(args.img_rows, args.img_cols))
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='test',
                           img_size=(args.img_rows, args.img_cols))

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=8)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom()

        loss_window = vis.line(X=torch.zeros((1, )).cpu(),
                               Y=torch.zeros((1)).cpu(),
                               opts=dict(xlabel='minibatches',
                                         ylabel='Loss',
                                         title='Training Loss',
                                         legend=['Loss']))
        pre_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='predict!', caption='predict.'),
        )
        ground_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='ground!', caption='ground.'),
        )

    # Setup Model, replicated across every visible GPU.
    model = get_model(args.arch, n_classes)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()

    # Use the model's own optimizer / loss when it defines them,
    # otherwise fall back to SGD + the module-level l1 loss.
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.l_rate,
                                    momentum=0.99,
                                    weight_decay=5e-4)

    if hasattr(model.module, 'loss'):
        print('Using custom loss')
        loss_fn = model.module.loss
    else:
        loss_fn = l1
    trained = 0
    scale = 100
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            trained = checkpoint['epoch']
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    best_error = 100
    best_rate = 100
    # Resume the epoch count from the checkpoint (if any).
    for epoch in range(trained, args.n_epoch):
        print('training!')
        model.train()
        for i, (images, labels) in enumerate(trainloader):
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())

            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(input=outputs, target=labels)
            loss.backward()
            optimizer.step()
            if args.visdom:
                # .item() extracts the scalar loss (loss.data[0] raises on
                # 0-dim tensors in modern PyTorch).
                vis.line(X=torch.ones(1).cpu() * i,
                         Y=loss.item() * torch.ones(1).cpu(),
                         win=loss_window,
                         update='append')
                # Show the first sample of the batch, scaled to [0, 1].
                pre = outputs.data.cpu().numpy().astype('float32')
                pre = pre[0, :, :, :]
                pre = np.reshape(pre,
                                 [480, 640]).astype('float32') / np.max(pre)
                vis.image(
                    pre,
                    opts=dict(title='predict!', caption='predict.'),
                    win=pre_window,
                )
                ground = labels.data.cpu().numpy().astype('float32')
                ground = ground[0, :, :]
                ground = np.reshape(
                    ground, [480, 640]).astype('float32') / np.max(ground)
                vis.image(
                    ground,
                    opts=dict(title='ground!', caption='ground.'),
                    win=ground_window,
                )
            print("data [%d/503/%d/%d] Loss: %.4f" %
                  (i, epoch, args.n_epoch, loss.item()))
        print('testing!')
        model.eval()
        error = []
        error_rate = []
        ones = np.ones([480, 640])
        zeros = np.zeros([480, 640])
        for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
            images_val = Variable(images_val.cuda(), requires_grad=False)
            labels_val = Variable(labels_val.cuda(), requires_grad=False)

            # no_grad replaces the removed volatile=True flag.
            with torch.no_grad():
                outputs = model(images_val)
            pred = outputs.data.cpu().numpy()
            gt = labels_val.data.cpu().numpy()
            # Reshape by the actual batch size so a short final batch
            # (len(dataset) % batch_size != 0) does not crash.
            gt = np.reshape(gt, (-1, 480, 640))
            pred = np.reshape(pred, gt.shape)
            dis = np.abs(gt - pred)
            error.append(np.mean(dis))
            # Fraction of pixels within 5 cm of ground truth
            # (ones/zeros broadcast over the batch axis).
            error_rate.append(np.mean(np.where(dis < 0.05, ones, zeros)))
        error = np.mean(error)
        error_rate = np.mean(error_rate)
        print("error=%.4f,error < 5 cm : %.4f" % (error, error_rate))
        if error <= best_error:
            best_error = error
            state = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                # Stored so resuming trainers can read checkpoint['error'].
                'error': error,
            }
            torch.save(state,
                       "{}_{}_best_model.pkl".format(args.arch, args.dataset))
def train(args):

    # Setup Augmentations
    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])
    loss_rec = []
    best_error = 2
    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path,
                           is_transform=True,
                           split='train_region',
                           img_size=(args.img_rows, args.img_cols),
                           task='all')
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='test_region',
                           img_size=(args.img_rows, args.img_cols),
                           task='all')

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=2,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=2)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom()
        # old_window = vis.line(X=torch.zeros((1,)).cpu(),
        #                        Y=torch.zeros((1)).cpu(),
        #                        opts=dict(xlabel='minibatches',
        #                                  ylabel='Loss',
        #                                  title='Trained Loss',
        #                                  legend=['Loss'])
        a_window = vis.line(X=torch.zeros((1, )).cpu(),
                            Y=torch.zeros((1)).cpu(),
                            opts=dict(xlabel='minibatches',
                                      ylabel='Loss',
                                      title='Region Loss1',
                                      legend=['Region']))
        loss_window1 = vis.line(X=torch.zeros((1, )).cpu(),
                                Y=torch.zeros((1)).cpu(),
                                opts=dict(xlabel='minibatches',
                                          ylabel='Loss',
                                          title='Training Loss1',
                                          legend=['Loss1']))
        loss_window2 = vis.line(X=torch.zeros((1, )).cpu(),
                                Y=torch.zeros((1)).cpu(),
                                opts=dict(xlabel='minibatches',
                                          ylabel='Loss',
                                          title='Training Loss2',
                                          legend=['Loss']))
        loss_window3 = vis.line(X=torch.zeros((1, )).cpu(),
                                Y=torch.zeros((1)).cpu(),
                                opts=dict(xlabel='minibatches',
                                          ylabel='Loss',
                                          title='Training Loss3',
                                          legend=['Loss3']))
        pre_window1 = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='predict1!', caption='predict1.'),
        )
        pre_window2 = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='predict2!', caption='predict2.'),
        )
        pre_window3 = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='predict3!', caption='predict3.'),
        )

        ground_window = vis.image(np.random.rand(480, 640),
                                  opts=dict(title='ground!',
                                            caption='ground.')),
        region_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='region!', caption='region.'),
        )
    cuda0 = torch.device('cuda:0')
    cuda1 = torch.device('cuda:1')
    cuda2 = torch.device('cuda:2')
    cuda3 = torch.device('cuda:3')
    # Setup Model
    rsnet = get_model('rsnet')
    rsnet = torch.nn.DataParallel(rsnet, device_ids=[0, 1])
    rsnet.cuda(cuda0)
    drnet = get_model('drnet')
    drnet = torch.nn.DataParallel(drnet, device_ids=[2, 3])
    drnet.cuda(cuda2)
    parameters = list(rsnet.parameters()) + list(drnet.parameters())
    # Check if model has custom optimizer / loss
    # modify to adam, modify the learning rate
    if hasattr(drnet.module, 'optimizer'):
        optimizer = drnet.module.optimizer
    else:
        # optimizer = torch.optim.Adam(
        #     rsnet.parameters(), lr=args.l_rate,weight_decay=5e-4,betas=(0.9,0.999))
        optimizer = torch.optim.SGD(rsnet.parameters(),
                                    lr=args.l_rate,
                                    momentum=0.99,
                                    weight_decay=5e-4)
    if hasattr(rsnet.module, 'loss'):
        print('Using custom loss')
        loss_fn = rsnet.module.loss
    else:
        loss_fn = log_r
        #loss_fn = region_r
    trained = 0
    scale = 100

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(
                '/home/lidong/Documents/RSDEN/RSDEN/rsnet_nyu_best_model.pkl')
            rsnet.load_state_dict(checkpoint['model_state'])
            #optimizer.load_state_dict(checkpoint['optimizer_state'])
            trained = checkpoint['epoch']
            best_error = checkpoint['error']
            print('load success from rsnet %.d' % trained)
            checkpoint = torch.load(
                '/home/lidong/Documents/RSDEN/RSDEN/drnet_nyu_best_model.pkl')
            drnet.load_state_dict(checkpoint['model_state'])
            #optimizer.load_state_dict(checkpoint['optimizer_state'])
            trained = checkpoint['epoch']
            print('load success from drnet %.d' % trained)

            #print('load success!')
            loss_rec = np.load('/home/lidong/Documents/RSDEN/RSDEN/loss.npy')
            loss_rec = list(loss_rec)
            loss_rec = loss_rec[:1632 * trained]
            # for i in range(300):
            #     loss_rec[i][1]=loss_rec[i+300][1]
            for l in range(int(len(loss_rec) / 1632)):
                if args.visdom:

                    vis.line(
                        X=torch.ones(1).cpu() * loss_rec[l * 1632][0],
                        Y=np.mean(
                            np.array(loss_rec[l * 1632:(l + 1) * 1632])[:, 1])
                        * torch.ones(1).cpu(),
                        win=old_window,
                        update='append')

    else:

        print("No checkpoint found at '{}'".format(args.resume))
        print('Initialize seperately!')
        checkpoint = torch.load(
            '/home/lidong/Documents/RSDEN/RSDEN/rsnet_nyu_135_model.pkl')
        rsnet.load_state_dict(checkpoint['model_state'])
        trained = checkpoint['epoch']
        best_error = checkpoint['error']
        print(best_error)
        print('load success from rsnet %.d' % trained)
        checkpoint = torch.load(
            '/home/lidong/Documents/RSDEN/RSDEN/drnet_nyu_135_model.pkl')

        # model_dict=drnet.state_dict()
        # pre_dict={k: v for k, v in checkpoint['model_state'].items() if k in model_dict}

        # model_dict.update(pre_dict)
        # #print(model_dict['module.conv1.weight'].shape)
        # model_dict['module.conv1.weight']=torch.cat([model_dict['module.conv1.weight'],torch.reshape(model_dict['module.conv1.weight'][:,3,:,:],[64,1,7,7])],1)
        # #print(model_dict['module.conv1.weight'].shape)
        # drnet.load_state_dict(model_dict)
        drnet.load_state_dict(checkpoint['model_state'])
        #optimizer.load_state_dict(checkpoint['optimizer_state'])
        trained = checkpoint['epoch']
        print('load success from drnet %.d' % trained)
        #trained=0
        loss_rec = []
        #loss_rec=np.load('/home/lidong/Documents/RSDEN/RSDEN/loss.npy')
        #loss_rec=list(loss_rec)
        #loss_rec=loss_rec[:1632*trained]
        #average_loss=checkpoint['error']
        # opti_dict=optimizer.state_dict()
        # #pre_dict={k: v for k, v in checkpoint['optimizer_state'].items() if k in opti_dict}
        # pre_dict=checkpoint['optimizer_state']
        # # for k,v in pre_dict.items():
        # #     print(k)
        # #     if k=='state':
        # #         #print(v.type)
        # #         for a,b in v.items():
        # #             print(a)
        # #             print(b['momentum_buffer'].shape)
        # #return 0
        # opti_dict.update(pre_dict)
        # # for k,v in opti_dict.items():
        # #     print(k)
        # #     if k=='state':
        # #         #print(v.type)
        # #         for a,b in v.items():
        # #             if a==140011149405280:
        # #                 print(b['momentum_buffer'].shape)
        # #print(opti_dict['state'][140011149405280]['momentum_buffer'].shape)
        # opti_dict['state'][140011149405280]['momentum_buffer']=torch.cat([opti_dict['state'][140011149405280]['momentum_buffer'],torch.reshape(opti_dict['state'][140011149405280]['momentum_buffer'][:,3,:,:],[64,1,7,7])],1)
        # #print(opti_dict['module.conv1.weight'].shape)
        # optimizer.load_state_dict(opti_dict)

    # it should be range(checkpoint[''epoch],args.n_epoch)
    for epoch in range(trained, args.n_epoch):
        #for epoch in range(0, args.n_epoch):

        #trained
        print('training!')
        rsnet.train()
        drnet.train()
        for i, (images, labels, segments) in enumerate(trainloader):
            images = images.cuda()
            labels = labels.cuda(cuda2)
            segments = segments.cuda(cuda2)
            #for error_sample in range(10):
            optimizer.zero_grad()
            #with torch.autograd.enable_grad():
            region_support = rsnet(images)
            #with torch.autograd.enable_grad():
            coarse_depth = torch.cat([images, region_support], 1)
            coarse_depth = torch.cat([coarse_depth, region_support], 1)
            #with torch.no_grad():
            outputs = drnet(coarse_depth)
            #outputs.append(region_support)
            #outputs=torch.reshape(outputs,[outputs.shape[0],1,outputs.shape[1],outputs.shape[2]])
            #outputs=outputs
            loss = loss_fn(input=outputs, target=labels)
            out = 0.2 * loss[0] + 0.3 * loss[1] + 0.5 * loss[2]
            #out=out
            a = l1(input=region_support, target=labels.to(cuda0))
            #a=region_log(input=region_support,target=labels.to(cuda0),instance=segments.to(cuda0)).to(cuda2)
            b = log_loss(region_support, labels.to(cuda0)).item()
            #out=0.8*out+0.02*a
            #a.backward()
            # print('training:'+str(i)+':learning_rate'+str(loss.data.cpu().numpy()))
            out.backward()
            optimizer.step()
            # print('out:%.4f,error_sample:%d'%(out.item(),error_sample))
            # if i==0:
            #     average_loss=(average_loss+out.item())/2
            #     break
            # if out.item()<average_loss/i:
            #     break
            # print(torch.Tensor([loss.data[0]]).unsqueeze(0).cpu())
            #print(loss.item()*torch.ones(1).cpu())
            #nyu2_train:246,nyu2_all:1632
            if args.visdom:
                vis.line(X=torch.ones(1).cpu() * i + torch.ones(1).cpu() *
                         (epoch - trained) * 1632,
                         Y=a.item() * torch.ones(1).cpu(),
                         win=a_window,
                         update='append')
                vis.line(X=torch.ones(1).cpu() * i + torch.ones(1).cpu() *
                         (epoch - trained) * 1632,
                         Y=loss[0].item() * torch.ones(1).cpu(),
                         win=loss_window1,
                         update='append')
                vis.line(X=torch.ones(1).cpu() * i + torch.ones(1).cpu() *
                         (epoch - trained) * 1632,
                         Y=loss[1].item() * torch.ones(1).cpu(),
                         win=loss_window2,
                         update='append')
                vis.line(X=torch.ones(1).cpu() * i + torch.ones(1).cpu() *
                         (epoch - trained) * 1632,
                         Y=loss[2].item() * torch.ones(1).cpu(),
                         win=loss_window3,
                         update='append')
                pre = outputs[0].data.cpu().numpy().astype('float32')
                pre = pre[0, :, :]
                #pre = np.argmax(pre, 0)
                pre = (np.reshape(pre, [480, 640]).astype('float32') -
                       np.min(pre)) / (np.max(pre) - np.min(pre))
                #pre = pre/np.max(pre)
                # print(type(pre[0,0]))
                vis.image(
                    pre,
                    opts=dict(title='predict1!', caption='predict1.'),
                    win=pre_window1,
                )
                pre = outputs[1].data.cpu().numpy().astype('float32')
                pre = pre[0, :, :]
                #pre = np.argmax(pre, 0)
                pre = (np.reshape(pre, [480, 640]).astype('float32') -
                       np.min(pre)) / (np.max(pre) - np.min(pre))
                #pre = pre/np.max(pre)
                # print(type(pre[0,0]))
                vis.image(
                    pre,
                    opts=dict(title='predict2!', caption='predict2.'),
                    win=pre_window2,
                )
                pre = outputs[2].data.cpu().numpy().astype('float32')
                pre = pre[0, :, :]
                #pre = np.argmax(pre, 0)
                pre = (np.reshape(pre, [480, 640]).astype('float32') -
                       np.min(pre)) / (np.max(pre) - np.min(pre))
                #pre = pre/np.max(pre)
                # print(type(pre[0,0]))
                vis.image(
                    pre,
                    opts=dict(title='predict3!', caption='predict3.'),
                    win=pre_window3,
                )
                ground = labels.data.cpu().numpy().astype('float32')
                #print(ground.shape)
                ground = ground[0, :, :]
                ground = (np.reshape(ground, [480, 640]).astype('float32') -
                          np.min(ground)) / (np.max(ground) - np.min(ground))
                vis.image(
                    ground,
                    opts=dict(title='ground!', caption='ground.'),
                    win=ground_window,
                )
                region_vis = region_support.data.cpu().numpy().astype(
                    'float32')
                #print(ground.shape)
                region_vis = region_vis[0, :, :]
                region_vis = (
                    np.reshape(region_vis, [480, 640]).astype('float32') -
                    np.min(region_vis)) / (np.max(region_vis) -
                                           np.min(region_vis))
                vis.image(
                    region_vis,
                    opts=dict(title='region_vis!', caption='region_vis.'),
                    win=region_window,
                )
            #average_loss+=out.item()
            loss_rec.append([
                i + epoch * 1632,
                torch.Tensor([loss[0].item()]).unsqueeze(0).cpu(),
                torch.Tensor([loss[1].item()]).unsqueeze(0).cpu(),
                torch.Tensor([loss[2].item()]).unsqueeze(0).cpu()
            ])
            print(
                "data [%d/1632/%d/%d]region:%.4f,%.4f Loss1: %.4f Loss2: %.4f Loss3: %.4f out:%.4f "
                % (i, epoch, args.n_epoch, a.item(), b, loss[0].item(),
                   loss[1].item(), loss[2].item(), out.item()))

        #average_loss=average_loss/816
        if epoch > 50:
            check = 1
        else:
            check = 1
        if epoch > 70:
            check = 1
        if epoch % check == 0:
            print('testing!')
            rsnet.train()
            drnet.train()
            error_lin = []
            error_log = []
            error_va = []
            error_rate = []
            error_absrd = []
            error_squrd = []
            thre1 = []
            thre2 = []
            thre3 = []

            for i_val, (images, labels,
                        segments) in tqdm(enumerate(valloader)):
                #print(r'\n')
                images = images.cuda()
                labels = labels.cuda()
                optimizer.zero_grad()
                print(i_val)

                with torch.no_grad():
                    region_support = rsnet(images)
                    coarse_depth = torch.cat([images, region_support], 1)
                    coarse_depth = torch.cat([coarse_depth, region_support], 1)
                    outputs = drnet(coarse_depth)
                    pred = outputs[2].data.cpu().numpy()
                    gt = labels.data.cpu().numpy()
                    ones = np.ones((gt.shape))
                    zeros = np.zeros((gt.shape))
                    pred = np.reshape(pred, (gt.shape))
                    #gt=np.reshape(gt,[4,480,640])
                    dis = np.square(gt - pred)
                    error_lin.append(np.sqrt(np.mean(dis)))
                    dis = np.square(np.log(gt) - np.log(pred))
                    error_log.append(np.sqrt(np.mean(dis)))
                    alpha = np.mean(np.log(gt) - np.log(pred))
                    dis = np.square(np.log(pred) - np.log(gt) + alpha)
                    error_va.append(np.mean(dis) / 2)
                    dis = np.mean(np.abs(gt - pred)) / gt
                    error_absrd.append(np.mean(dis))
                    dis = np.square(gt - pred) / gt
                    error_squrd.append(np.mean(dis))
                    thelt = np.where(pred / gt > gt / pred, pred / gt,
                                     gt / pred)
                    thres1 = 1.25

                    thre1.append(np.mean(np.where(thelt < thres1, ones,
                                                  zeros)))
                    thre2.append(
                        np.mean(np.where(thelt < thres1 * thres1, ones,
                                         zeros)))
                    thre3.append(
                        np.mean(
                            np.where(thelt < thres1 * thres1 * thres1, ones,
                                     zeros)))
                    #a=thre1[i_val]
                    #error_rate.append(np.mean(np.where(dis<0.6,ones,zeros)))
                    print(
                        "error_lin=%.4f,error_log=%.4f,error_va=%.4f,error_absrd=%.4f,error_squrd=%.4f,thre1=%.4f,thre2=%.4f,thre3=%.4f"
                        % (error_lin[i_val], error_log[i_val], error_va[i_val],
                           error_absrd[i_val], error_squrd[i_val],
                           thre1[i_val], thre2[i_val], thre3[i_val]))
                    # if i_val > 219/check:
                    #     break
            error = np.mean(error_lin)
            #error_rate=np.mean(error_rate)
            print("error=%.4f" % (error))

            if error <= best_error:
                best_error = error
                state = {
                    'epoch': epoch + 1,
                    'model_state': rsnet.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'error': error,
                }
                torch.save(
                    state,
                    "{}_{}_best_model.pkl".format('rsnet', args.dataset))
                state = {
                    'epoch': epoch + 1,
                    'model_state': drnet.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'error': error,
                }
                torch.save(
                    state,
                    "{}_{}_best_model.pkl".format('drnet', args.dataset))
                print('save success')
            np.save('/home/lidong/Documents/RSDEN/RSDEN//loss.npy', loss_rec)
        if epoch % 3 == 0:
            #best_error = error
            state = {
                'epoch': epoch + 1,
                'model_state': rsnet.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'error': error,
            }
            torch.save(
                state, "{}_{}_{}_model.pkl".format('rsnet', args.dataset,
                                                   str(epoch)))
            state = {
                'epoch': epoch + 1,
                'model_state': drnet.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'error': error,
            }
            torch.save(
                state, "{}_{}_{}_model.pkl".format('drnet', args.dataset,
                                                   str(epoch)))
            print('save success')
def train(args):
    """Build the rsnet -> drnet depth-estimation cascade and evaluate it.

    Loads both sub-networks from hard-coded checkpoint paths (or from
    ``args.resume``), runs one pass over the validation split computing
    the standard depth metrics, prints the epoch means, and exits.

    NOTE(review): despite the name, the epoch loop below only runs the
    "testing" branch and then ``break``s, so no optimizer step is ever
    taken in this function.
    """

    # Setup Augmentations
    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])
    loss_rec = []
    # Initial "best" error; overwritten with the checkpointed value below.
    best_error = 2
    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path,
                           is_transform=True,
                           split='train_region',
                           img_size=(args.img_rows, args.img_cols),
                           task='all')
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='test_region',
                           img_size=(args.img_rows, args.img_cols),
                           task='all')

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=2,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=2)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom()
        # old_window = vis.line(X=torch.zeros((1,)).cpu(),
        #                        Y=torch.zeros((1)).cpu(),
        #                        opts=dict(xlabel='minibatches',
        #                                  ylabel='Loss',
        #                                  title='Trained Loss',
        #                                  legend=['Loss']))
        loss_window1 = vis.line(X=torch.zeros((1, )).cpu(),
                                Y=torch.zeros((1)).cpu(),
                                opts=dict(xlabel='minibatches',
                                          ylabel='Loss',
                                          title='Training Loss1',
                                          legend=['Loss1']))
        loss_window2 = vis.line(X=torch.zeros((1, )).cpu(),
                                Y=torch.zeros((1)).cpu(),
                                opts=dict(xlabel='minibatches',
                                          ylabel='Loss',
                                          title='Training Loss2',
                                          legend=['Loss']))
        loss_window3 = vis.line(X=torch.zeros((1, )).cpu(),
                                Y=torch.zeros((1)).cpu(),
                                opts=dict(xlabel='minibatches',
                                          ylabel='Loss',
                                          title='Training Loss3',
                                          legend=['Loss3']))
        pre_window1 = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='predict1!', caption='predict1.'),
        )
        pre_window2 = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='predict2!', caption='predict2.'),
        )
        pre_window3 = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='predict3!', caption='predict3.'),
        )

        ground_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='ground!', caption='ground.'),
        )
    # Pin each sub-network to its own GPU pair (assumes a 4-GPU machine):
    # rsnet on cuda:0/1, drnet on cuda:2/3.
    cuda0 = torch.device('cuda:0')
    cuda1 = torch.device('cuda:1')
    cuda2 = torch.device('cuda:2')
    cuda3 = torch.device('cuda:3')
    # Setup Model
    rsnet = get_model('rsnet')
    rsnet = torch.nn.DataParallel(rsnet, device_ids=[0, 1])
    rsnet.cuda(cuda0)
    drnet = get_model('drnet')
    drnet = torch.nn.DataParallel(drnet, device_ids=[2, 3])
    drnet.cuda(cuda2)
    # A single optimizer updates the parameters of both networks jointly.
    parameters = list(rsnet.parameters()) + list(drnet.parameters())
    # Check if model has custom optimizer / loss
    # modify to adam, modify the learning rate
    if hasattr(drnet.module, 'optimizer'):
        optimizer = drnet.module.optimizer
    else:
        # optimizer = torch.optim.Adam(
        #     model.parameters(), lr=args.l_rate,weight_decay=5e-4,betas=(0.9,0.999))
        optimizer = torch.optim.SGD(parameters,
                                    lr=args.l_rate,
                                    momentum=0.99,
                                    weight_decay=5e-4)
    if hasattr(rsnet.module, 'loss'):
        print('Using custom loss')
        loss_fn = rsnet.module.loss
    else:
        loss_fn = l1_r
    trained = 0
    scale = 100

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)
            #model_dict=model.state_dict()
            #opt=torch.load('/home/lidong/Documents/RSDEN/RSDEN/exp1/l2/sgd/log/83/rsnet_nyu_best_model.pkl')
            # NOTE(review): `model` is never defined in this function (only
            # `rsnet` and `drnet` exist), so resuming would raise NameError
            # here — confirm which network this checkpoint belongs to.
            model.load_state_dict(checkpoint['model_state'])
            #optimizer.load_state_dict(checkpoint['optimizer_state'])
            #opt=None
            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            trained = checkpoint['epoch']
            best_error = checkpoint['error']

            #print('load success!')
            # Replay recorded loss history; 1632 is the batches-per-epoch
            # count used throughout this script.
            loss_rec = np.load('/home/lidong/Documents/RSDEN/RSDEN/loss.npy')
            loss_rec = list(loss_rec)
            loss_rec = loss_rec[:1632 * trained]
            # for i in range(300):
            #     loss_rec[i][1]=loss_rec[i+300][1]
            for l in range(int(len(loss_rec) / 1632)):
                if args.visdom:

                    # NOTE(review): `old_window` is only created in the
                    # commented-out vis.line call near the top of this
                    # function, so this would raise NameError when reached
                    # with visdom enabled — confirm intended window.
                    vis.line(
                        X=torch.ones(1).cpu() * loss_rec[l * 1632][0],
                        Y=np.mean(
                            np.array(loss_rec[l * 1632:(l + 1) * 1632])[:, 1])
                        * torch.ones(1).cpu(),
                        win=old_window,
                        update='append')

    else:

        print("No checkpoint found at '{}'".format(args.resume))
        print('Initialize seperately!')
        # No resume checkpoint: load each sub-network from its own
        # pre-trained "best" checkpoint instead.
        checkpoint = torch.load(
            '/home/lidong/Documents/RSDEN/RSDEN/rsnet_nyu_best_model.pkl')
        rsnet.load_state_dict(checkpoint['model_state'])
        trained = checkpoint['epoch']
        print('load success from rsnet %.d' % trained)
        best_error = checkpoint['error']
        checkpoint = torch.load(
            '/home/lidong/Documents/RSDEN/RSDEN/drnet_nyu_best_model.pkl')
        drnet.load_state_dict(checkpoint['model_state'])
        #optimizer.load_state_dict(checkpoint['optimizer_state'])
        trained = checkpoint['epoch']
        print('load success from drnet %.d' % trained)
        trained = 0

    # it should be range(checkpoint[''epoch],args.n_epoch)
    for epoch in range(trained, args.n_epoch):

        # NOTE(review): both here and in the "testing" branch below the
        # networks are put in .train() mode; .eval() is presumably intended
        # for evaluation (freezes batch-norm/dropout) — confirm.
        rsnet.train()
        drnet.train()

        if epoch % 1 == 0:
            print('testing!')
            rsnet.train()
            drnet.train()
            # Per-batch accumulators for the standard depth metrics.
            error_lin = []
            error_log = []
            error_va = []
            error_rate = []
            error_absrd = []
            error_squrd = []
            thre1 = []
            thre2 = []
            thre3 = []

            for i_val, (images, labels,
                        segments) in tqdm(enumerate(valloader)):
                #print(r'\n')
                images = images.cuda()
                labels = labels.cuda()
                optimizer.zero_grad()
                print(i_val)

                with torch.no_grad():
                    # Cascade: rsnet's region support map is concatenated
                    # (twice) with the input image and fed to drnet.
                    region_support = rsnet(images)
                    coarse_depth = torch.cat([images, region_support], 1)
                    coarse_depth = torch.cat([coarse_depth, region_support], 1)
                    outputs = drnet(coarse_depth)
                    # The final (third) output head is used as the prediction.
                    pred = outputs[2].data.cpu().numpy()
                    gt = labels.data.cpu().numpy()
                    ones = np.ones((gt.shape))
                    zeros = np.zeros((gt.shape))
                    pred = np.reshape(pred, (gt.shape))
                    #gt=np.reshape(gt,[4,480,640])
                    # Linear RMSE.
                    dis = np.square(gt - pred)
                    error_lin.append(np.sqrt(np.mean(dis)))
                    # Log RMSE.
                    dis = np.square(np.log(gt) - np.log(pred))
                    error_log.append(np.sqrt(np.mean(dis)))
                    # Scale-invariant log error.
                    alpha = np.mean(np.log(gt) - np.log(pred))
                    dis = np.square(np.log(pred) - np.log(gt) + alpha)
                    error_va.append(np.mean(dis) / 2)
                    # NOTE(review): this divides the *scalar* mean absolute
                    # error by gt elementwise; abs-rel is usually
                    # np.abs(gt - pred) / gt before the mean — confirm.
                    dis = np.mean(np.abs(gt - pred)) / gt
                    error_absrd.append(np.mean(dis))
                    # Squared relative difference.
                    dis = np.square(gt - pred) / gt
                    error_squrd.append(np.mean(dis))
                    # Threshold accuracy: fraction with
                    # max(pred/gt, gt/pred) < 1.25 ** k, k = 1, 2, 3.
                    thelt = np.where(pred / gt > gt / pred, pred / gt,
                                     gt / pred)
                    thres1 = 1.25

                    thre1.append(np.mean(np.where(thelt < thres1, ones,
                                                  zeros)))
                    thre2.append(
                        np.mean(np.where(thelt < thres1 * thres1, ones,
                                         zeros)))
                    thre3.append(
                        np.mean(
                            np.where(thelt < thres1 * thres1 * thres1, ones,
                                     zeros)))
                    #a=thre1[i_val]
                    #error_rate.append(np.mean(np.where(dis<0.6,ones,zeros)))
                    print(
                        "error_lin=%.4f,error_log=%.4f,error_va=%.4f,error_absrd=%.4f,error_squrd=%.4f,thre1=%.4f,thre2=%.4f,thre3=%.4f"
                        % (error_lin[i_val], error_log[i_val], error_va[i_val],
                           error_absrd[i_val], error_squrd[i_val],
                           thre1[i_val], thre2[i_val], thre3[i_val]))
            # NOTE(review): `i_val` is the index of the *last* batch here,
            # so only that batch's metrics are persisted; the epoch means
            # are computed just below — confirm which was intended.
            np.save('/home/lidong/Documents/RSDEN/RSDEN//error_train.npy', [
                error_lin[i_val], error_log[i_val], error_va[i_val],
                error_absrd[i_val], error_squrd[i_val], thre1[i_val],
                thre2[i_val], thre3[i_val]
            ])
            # Collapse the per-batch lists into split-wide means.
            error_lin = np.mean(error_lin)
            error_log = np.mean(error_log)
            error_va = np.mean(error_va)
            error_absrd = np.mean(error_absrd)
            error_squrd = np.mean(error_squrd)
            thre1 = np.mean(thre1)
            thre2 = np.mean(thre2)
            thre3 = np.mean(thre3)

            print('Final Result!')
            print(
                "error_lin=%.4f,error_log=%.4f,error_va=%.4f,error_absrd=%.4f,error_squrd=%.4f,thre1=%.4f,thre2=%.4f,thre3=%.4f"
                % (error_lin, error_log, error_va, error_absrd, error_squrd,
                   thre1, thre2, thre3))
            # Evaluate once, then leave the epoch loop entirely.
            break
Example #10
0
def train(args):

    # Setup Augmentations
    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])
    loss_rec = []
    best_error = 2
    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path,
                           is_transform=True,
                           split='train',
                           img_size=(args.img_rows, args.img_cols),
                           task='region')
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='test',
                           img_size=(args.img_rows, args.img_cols),
                           task='region')

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=4,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=4)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom()

        depth_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='depth!', caption='depth.'),
        )
        mask_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='mask!', caption='mask.'),
        )
        region_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='region!', caption='region.'),
        )
        ground_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='ground!', caption='ground.'),
        )
        loss_window = vis.line(X=torch.zeros((1, )).cpu(),
                               Y=torch.zeros((1)).cpu(),
                               opts=dict(xlabel='minibatches',
                                         ylabel='Loss',
                                         title='Training Loss',
                                         legend=['Loss']))
        old_window = vis.line(X=torch.zeros((1, )).cpu(),
                              Y=torch.zeros((1)).cpu(),
                              opts=dict(xlabel='minibatches',
                                        ylabel='Loss',
                                        title='Trained Loss',
                                        legend=['Loss']))
    # Setup Model
    model = get_model(args.arch)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    #model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    model.cuda()

    # Check if model has custom optimizer / loss
    # modify to adam, modify the learning rate
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.l_rate,
                                     weight_decay=5e-4,
                                     betas=(0.9, 0.999))
        # optimizer = torch.optim.SGD(
        #     model.parameters(), lr=args.l_rate,momentum=0.90, weight_decay=5e-4)
    if hasattr(model.module, 'loss'):
        print('Using custom loss')
        loss_fn = model.module.loss
    else:
        loss_fn = log_loss
    trained = 0
    scale = 100

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            #model_dict=model.state_dict()
            #opt=torch.load('/home/lidong/Documents/RSDEN/RSDEN/exp1/l2/sgd/log/83/rsnet_nyu_best_model.pkl')
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            #opt=None
            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            trained = checkpoint['epoch']
            best_error = checkpoint['error']
            print(best_error)
            loss_rec = np.load('/home/lidong/Documents/RSDEN/RSDEN/loss.npy')
            loss_rec = list(loss_rec)
            loss_rec = loss_rec[:179 * trained]
            # for i in range(300):
            #     loss_rec[i][1]=loss_rec[i+300][1]
            for l in range(int(len(loss_rec) / 179)):
                if args.visdom:

                    vis.line(
                        X=torch.ones(1).cpu() * loss_rec[l * 179][0],
                        Y=np.mean(
                            np.array(loss_rec[l * 179:(l + 1) * 179])[:, 1]) *
                        torch.ones(1).cpu(),
                        win=old_window,
                        update='append')

    else:

        print("No checkpoint found at '{}'".format(args.resume))
        print('Initialize from rsn!')
        rsn = torch.load(
            '/home/lidong/Documents/RSDEN/RSDEN/rsn_mask_nyu2_best_model.pkl',
            map_location='cpu')
        model_dict = model.state_dict()
        #print(model_dict)
        pre_dict = {
            k: v
            for k, v in rsn['model_state'].items()
            if k in model_dict and rsn['model_state'].items()
        }
        key = []
        for k, v in pre_dict.items():
            if v.shape != model_dict[k].shape:
                key.append(k)
        for k in key:
            pre_dict.pop(k)
        model_dict.update(pre_dict)
        model.load_state_dict(model_dict)
        print('load success!')
        best_error = 100
        trained = 0
        del rsn

    # it should be range(checkpoint[''epoch],args.n_epoch)
    for epoch in range(trained, args.n_epoch):
        #for epoch in range(0, args.n_epoch):

        #trained
        print('training!')
        model.train()
        for i, (images, labels, regions, segments) in enumerate(trainloader):
            #break
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())
            segments = Variable(segments.cuda())
            regions = Variable(regions.cuda())
            #break
            optimizer.zero_grad()
            #outputs,mask = model(images)
            mask = model(images)
            outputs = regions
            #loss_d = region_log(outputs,labels,segments)
            segments = torch.reshape(
                segments, [mask.shape[0], mask.shape[2], mask.shape[3]])
            #loss_m = mask_loss(input=mask,target=segments)
            loss_m = mask_loss_region(mask, segments)
            #region=segments
            #print(loss_m)
            #mask_map=torch.argmax(mask)
            #loss_r,region= region_loss(outputs,mask,regions,segments)
            #loss_c=loss_d
            # print('training:'+str(i)+':learning_rate'+str(loss.data.cpu().numpy()))
            #loss=0.5*loss_d+0.5*(loss_m+loss_r)
            #break
            #loss_d=loss_r
            #loss=0.25*loss_r+0.5*loss_m+0.25*loss_d
            loss_d = loss_m
            loss_r = loss_m
            region = segments
            loss = loss_m
            loss.backward()
            optimizer.step()
            # print(torch.Tensor([loss.data[0]]).unsqueeze(0).cpu())
            #print(loss.item()*torch.ones(1).cpu())
            #nyu2_train:246,nyu2_all:179
            if args.visdom:
                vis.line(X=torch.ones(1).cpu() * i + torch.ones(1).cpu() *
                         (epoch - trained) * 179,
                         Y=loss.item() * torch.ones(1).cpu(),
                         win=loss_window,
                         update='append')
                depth = outputs.data.cpu().numpy().astype('float32')
                depth = depth[0, :, :, :]
                depth = (np.reshape(depth, [480, 640]).astype('float32') -
                         np.min(depth)) / (np.max(depth) - np.min(depth) + 1)
                vis.image(
                    depth,
                    opts=dict(title='depth!', caption='depth.'),
                    win=depth_window,
                )
                mask = torch.argmax(mask,
                                    dim=1).data.cpu().numpy().astype('float32')
                mask = mask[0, ...]
                mask = (np.reshape(mask, [480, 640]).astype('float32') -
                        np.min(mask)) / (np.max(mask) - np.min(mask) + 1)
                vis.image(
                    mask,
                    opts=dict(title='mask!', caption='mask.'),
                    win=mask_window,
                )
                region = region.data.cpu().numpy().astype('float32')
                region = region[0, ...]
                region = (np.reshape(region, [480, 640]).astype('float32') -
                          np.min(region)) / (np.max(region) - np.min(region) +
                                             1)
                vis.image(
                    region,
                    opts=dict(title='region!', caption='region.'),
                    win=region_window,
                )
                ground = regions.data.cpu().numpy().astype('float32')
                ground = ground[0, :, :]
                ground = (np.reshape(ground, [480, 640]).astype('float32') -
                          np.min(ground)) / (np.max(ground) - np.min(ground) +
                                             1)
                vis.image(
                    ground,
                    opts=dict(title='ground!', caption='ground.'),
                    win=ground_window,
                )

            loss_rec.append([
                i + epoch * 179,
                torch.Tensor([loss.item()]).unsqueeze(0).cpu()
            ])
            print(
                "data [%d/179/%d/%d] Loss: %.4f Lossd: %.4f Lossm: %.4f Lossr: %.4f"
                % (i, epoch, args.n_epoch, loss.item(), loss_d.item(),
                   loss_m.item(), loss_r.item()))
        if epoch > 30:
            check = 5
        else:
            check = 10
        if epoch > 50:
            check = 3
        if epoch > 70:
            check = 1
        #epoch=3
        if epoch % check == 0:

            print('testing!')
            model.eval()
            loss_ave = []

            for i_val, (images_val, labels_val, regions,
                        segments) in tqdm(enumerate(valloader)):
                #print(r'\n')
                images_val = Variable(images_val.cuda(), requires_grad=False)
                labels_val = Variable(labels_val.cuda(), requires_grad=False)
                segments_val = Variable(segments.cuda(), requires_grad=False)
                regions_val = Variable(regions.cuda(), requires_grad=False)
                with torch.no_grad():
                    #outputs,mask = model(images_val)
                    mask = model(images_val)
                    outputs = regions
                    #region= region_generation(outputs,mask,regions_val,segments_val)
                    #loss_d = l2(input=region, target=regions_val)
                    segments_val = torch.reshape(
                        segments_val,
                        [mask.shape[0], mask.shape[2], mask.shape[3]])
                    #loss_r,region= region_loss(outputs,mask,regions_val,segments_val)
                    loss_r = mask_loss_region(mask, segments_val)
                    loss_ave.append(loss_r.data.cpu().numpy())
                    print(loss_ave[-1])
                    #exit()
            error = np.mean(loss_ave)
            #error_rate=np.mean(error_rate)
            print("error=%.4f" % (error))

            if error <= best_error:
                best_error = error
                state = {
                    'epoch': epoch + 1,
                    'model_state': model.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'error': error,
                }
                torch.save(
                    state,
                    "{}_{}_best_model.pkl".format(args.arch, args.dataset))
                print('save success')
            np.save('/home/lidong/Documents/RSDEN/RSDEN/loss.npy', loss_rec)
        if epoch % 15 == 0:
            #best_error = error
            state = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'error': error,
            }
            torch.save(
                state, "{}_{}_{}_model.pkl".format(args.arch, args.dataset,
                                                   str(epoch)))
            print('save success')
Example #11
0
def train(args):
    """Train the two-stage region-support (rsnet) + depth-refinement (drnet) pipeline.

    rsnet runs on GPU 0 and predicts a region-support map from the RGB
    image; drnet runs on GPU 1 and predicts depth from the image
    concatenated with that map.  Both networks are optimized jointly with
    SGD.  Every 3 epochs the pipeline is evaluated on the validation
    split (linear RMSE, log RMSE, scale-invariant error, relative errors
    and the 1.25^k threshold accuracies); the lowest-RMSE checkpoint is
    saved as "<arch>_<dataset>_best_model.pkl" and a periodic snapshot is
    written every 15 epochs.

    Args:
        args: parsed CLI namespace; uses dataset, img_rows, img_cols,
            batch_size, visdom, resume, l_rate, n_epoch and arch.
    """
    # Setup Augmentations (constructed but not passed to the loaders below)
    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])
    loss_rec = []
    best_error = 2
    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path,
                           is_transform=True,
                           split='train_region',
                           img_size=(args.img_rows, args.img_cols))
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='test_region',
                           img_size=(args.img_rows, args.img_cols))

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=2,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=2)

    # Setup Metrics (not updated in this function; kept for parity with
    # the sibling training scripts in this project)
    running_metrics = runningScore(n_classes)

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom()
        # Window that replays the historical loss curve when resuming.
        # It was previously commented out while still being referenced in
        # the resume branch, which raised a NameError.
        old_window = vis.line(X=torch.zeros((1, )).cpu(),
                              Y=torch.zeros((1)).cpu(),
                              opts=dict(xlabel='minibatches',
                                        ylabel='Loss',
                                        title='Trained Loss',
                                        legend=['Loss']))
        loss_window1 = vis.line(X=torch.zeros((1, )).cpu(),
                                Y=torch.zeros((1)).cpu(),
                                opts=dict(xlabel='minibatches',
                                          ylabel='Loss',
                                          title='Training Loss1',
                                          legend=['Loss1']))
        loss_window2 = vis.line(X=torch.zeros((1, )).cpu(),
                                Y=torch.zeros((1)).cpu(),
                                opts=dict(xlabel='minibatches',
                                          ylabel='Loss',
                                          title='Training Loss2',
                                          legend=['Loss']))
        loss_window3 = vis.line(X=torch.zeros((1, )).cpu(),
                                Y=torch.zeros((1)).cpu(),
                                opts=dict(xlabel='minibatches',
                                          ylabel='Loss',
                                          title='Training Loss3',
                                          legend=['Loss3']))
        pre_window1 = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='predict1!', caption='predict1.'),
        )
        pre_window2 = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='predict2!', caption='predict2.'),
        )
        pre_window3 = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='predict3!', caption='predict3.'),
        )
        ground_window = vis.image(
            np.random.rand(480, 640),
            opts=dict(title='ground!', caption='ground.'),
        )

    cuda0 = torch.device('cuda:0')
    cuda1 = torch.device('cuda:1')
    # Setup Model: rsnet on GPU 0, drnet on GPU 1 (model parallelism).
    rsnet = get_model('rsnet')
    rsnet = torch.nn.DataParallel(rsnet, device_ids=[0])
    rsnet.to(cuda0)
    drnet = get_model('drnet')
    drnet = torch.nn.DataParallel(drnet, device_ids=[1])
    drnet.to(cuda1)
    # Jointly optimize both sub-networks.
    parameters = list(rsnet.parameters()) + list(drnet.parameters())
    # Check if model has custom optimizer / loss
    if hasattr(drnet.module, 'optimizer'):
        optimizer = drnet.module.optimizer
    else:
        optimizer = torch.optim.SGD(parameters,
                                    lr=args.l_rate,
                                    momentum=0.99,
                                    weight_decay=5e-4)
    if hasattr(rsnet.module, 'loss'):
        print('Using custom loss')
        loss_fn = rsnet.module.loss
    else:
        loss_fn = l1_r
    trained = 0

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)
            # Joint checkpoints store each sub-network separately (see the
            # save sites below); the old code loaded into an undefined
            # name `model`, which raised a NameError.
            rsnet.load_state_dict(checkpoint['rsnet_state'])
            drnet.load_state_dict(checkpoint['drnet_state'])
            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            trained = checkpoint['epoch']
            best_error = checkpoint['error']

            # Replay the recorded loss curve (3265 = iterations per epoch
            # on nyu2_all) so the visdom plot is continuous across runs.
            loss_rec = np.load('/home/lidong/Documents/RSDEN/RSDEN/loss.npy')
            loss_rec = list(loss_rec)
            loss_rec = loss_rec[:3265 * trained]
            for l in range(int(len(loss_rec) / 3265)):
                if args.visdom:
                    # NOTE(review): assumes column 1 of each record is a
                    # scalar-convertible loss value — verify against the
                    # append format used in the training loop below.
                    vis.line(
                        X=torch.ones(1).cpu() * loss_rec[l * 3265][0],
                        Y=np.mean(
                            np.array(loss_rec[l * 3265:(l + 1) * 3265])[:, 1])
                        * torch.ones(1).cpu(),
                        win=old_window,
                        update='append')

    else:
        # No joint checkpoint: initialize each sub-network from its own
        # separately-trained best checkpoint, then train jointly from
        # epoch 0.
        print("No checkpoint found at '{}'".format(args.resume))
        print('Initialize seperately!')
        checkpoint = torch.load(
            '/home/lidong/Documents/RSDEN/RSDEN/rsnet_nyu_best_model.pkl')
        rsnet.load_state_dict(checkpoint['model_state'])
        trained = checkpoint['epoch']
        print('load success from rsnet %.d' % trained)
        checkpoint = torch.load(
            '/home/lidong/Documents/RSDEN/RSDEN/drnet_nyu_best_model.pkl')
        drnet.load_state_dict(checkpoint['model_state'])
        trained = checkpoint['epoch']
        print('load success from drnet %.d' % trained)
        trained = 0
        best_error = 1

    for epoch in range(trained, args.n_epoch):

        print('training!')
        rsnet.train()
        drnet.train()
        for i, (images, labels, segments) in enumerate(trainloader):
            images = images.to(cuda0)
            labels = labels.to(cuda1)
            optimizer.zero_grad()
            region_support = rsnet(images)
            # drnet lives on cuda:1, so the concatenated input must be
            # moved there (the validation loop already did this; the
            # training loop previously did not, a cross-GPU mismatch).
            coarse_depth = torch.cat([images, region_support], 1).to(cuda1)
            outputs = drnet(coarse_depth)
            # loss_fn returns three per-scale losses; optimize their sum.
            loss = loss_fn(input=outputs, target=labels)
            out = loss[0] + loss[1] + loss[2]
            out.backward()
            optimizer.step()
            # nyu2_train: 246 iterations/epoch, nyu2_all: 3265
            if args.visdom:
                vis.line(X=torch.ones(1).cpu() * i + torch.ones(1).cpu() *
                         (epoch - trained) * 3265,
                         Y=loss[0].item() * torch.ones(1).cpu(),
                         win=loss_window1,
                         update='append')
                vis.line(X=torch.ones(1).cpu() * i + torch.ones(1).cpu() *
                         (epoch - trained) * 3265,
                         Y=loss[1].item() * torch.ones(1).cpu(),
                         win=loss_window2,
                         update='append')
                vis.line(X=torch.ones(1).cpu() * i + torch.ones(1).cpu() *
                         (epoch - trained) * 3265,
                         Y=loss[2].item() * torch.ones(1).cpu(),
                         win=loss_window3,
                         update='append')
                # Show the first sample of each prediction scale,
                # min-max normalized to [0, 1] for display.
                # NOTE(review): a constant map makes max==min and divides
                # by zero here — confirm predictions are never flat.
                pre = outputs[0].data.cpu().numpy().astype('float32')
                pre = pre[0, :, :]
                pre = (np.reshape(pre, [480, 640]).astype('float32') -
                       np.min(pre)) / (np.max(pre) - np.min(pre))
                vis.image(
                    pre,
                    opts=dict(title='predict1!', caption='predict1.'),
                    win=pre_window1,
                )
                pre = outputs[1].data.cpu().numpy().astype('float32')
                pre = pre[0, :, :]
                pre = (np.reshape(pre, [480, 640]).astype('float32') -
                       np.min(pre)) / (np.max(pre) - np.min(pre))
                vis.image(
                    pre,
                    opts=dict(title='predict2!', caption='predict2.'),
                    win=pre_window2,
                )
                pre = outputs[2].data.cpu().numpy().astype('float32')
                pre = pre[0, :, :]
                pre = (np.reshape(pre, [480, 640]).astype('float32') -
                       np.min(pre)) / (np.max(pre) - np.min(pre))
                vis.image(
                    pre,
                    opts=dict(title='predict3!', caption='predict3.'),
                    win=pre_window3,
                )
                ground = labels.data.cpu().numpy().astype('float32')
                ground = ground[0, :, :]
                ground = (np.reshape(ground, [480, 640]).astype('float32') -
                          np.min(ground)) / (np.max(ground) - np.min(ground))
                vis.image(
                    ground,
                    opts=dict(title='ground!', caption='ground.'),
                    win=ground_window,
                )

            loss_rec.append([
                i + epoch * 3265,
                torch.Tensor([loss[0].item()]).unsqueeze(0).cpu(),
                torch.Tensor([loss[1].item()]).unsqueeze(0).cpu(),
                torch.Tensor([loss[2].item()]).unsqueeze(0).cpu()
            ])
            print("data [%d/3265/%d/%d] Loss1: %.4f Loss2: %.4f Loss3: %.4f" %
                  (i, epoch, args.n_epoch, loss[0].item(), loss[1].item(),
                   loss[2].item()))

        if epoch % 3 == 0:
            print('testing!')
            # Evaluation must run in eval mode (the old code called
            # .train() here, leaving dropout/batch-norm in training mode).
            rsnet.eval()
            drnet.eval()
            error_lin = []
            error_log = []
            error_va = []
            error_absrd = []
            error_squrd = []
            thre1 = []
            thre2 = []
            thre3 = []

            for i_val, (images, labels,
                        segments) in tqdm(enumerate(valloader)):
                images = images.to(cuda0)
                labels = labels.to(cuda1)

                with torch.no_grad():
                    region_support = rsnet(images)
                    coarse_depth = torch.cat([images, region_support],
                                             1).to(cuda1)
                    outputs = drnet(coarse_depth)
                    # Evaluate only the finest prediction scale.
                    pred = outputs[2].data.cpu().numpy()
                    gt = labels.data.cpu().numpy()
                    ones = np.ones((gt.shape))
                    zeros = np.zeros((gt.shape))
                    pred = np.reshape(pred, (gt.shape))
                    # Linear RMSE.
                    dis = np.square(gt - pred)
                    error_lin.append(np.sqrt(np.mean(dis)))
                    # NOTE(review): the log-based metrics below assume
                    # gt, pred > 0; non-positive values yield nan/inf.
                    dis = np.square(np.log(gt) - np.log(pred))
                    error_log.append(np.sqrt(np.mean(dis)))
                    # Scale-invariant error (Eigen et al.).
                    alpha = np.mean(np.log(gt) - np.log(pred))
                    dis = np.square(np.log(pred) - np.log(gt) + alpha)
                    error_va.append(np.mean(dis) / 2)
                    # Absolute relative error: mean of |gt-pred|/gt (the
                    # old code averaged |gt-pred| before dividing, which
                    # is not the standard metric).
                    dis = np.abs(gt - pred) / gt
                    error_absrd.append(np.mean(dis))
                    # Squared relative error.
                    dis = np.square(gt - pred) / gt
                    error_squrd.append(np.mean(dis))
                    # Threshold accuracies delta < 1.25^k, k = 1, 2, 3.
                    thelt = np.where(pred / gt > gt / pred, pred / gt,
                                     gt / pred)
                    thres1 = 1.25

                    thre1.append(np.mean(np.where(thelt < thres1, ones,
                                                  zeros)))
                    thre2.append(
                        np.mean(np.where(thelt < thres1 * thres1, ones,
                                         zeros)))
                    thre3.append(
                        np.mean(
                            np.where(thelt < thres1 * thres1 * thres1, ones,
                                     zeros)))
                    print(
                        "error_lin=%.4f,error_log=%.4f,error_va=%.4f,error_absrd=%.4f,error_squrd=%.4f,thre1=%.4f,thre2=%.4f,thre3=%.4f"
                        % (error_lin[i_val], error_log[i_val], error_va[i_val],
                           error_absrd[i_val], error_squrd[i_val],
                           thre1[i_val], thre2[i_val], thre3[i_val]))
            error = np.mean(error_lin)
            print("error=%.4f" % (error))

            if error <= best_error:
                best_error = error
                # Save both sub-networks (the old code saved an undefined
                # name `model`, raising a NameError on every save).
                state = {
                    'epoch': epoch + 1,
                    'rsnet_state': rsnet.state_dict(),
                    'drnet_state': drnet.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'error': error,
                }
                torch.save(
                    state,
                    "{}_{}_best_model.pkl".format(args.arch, args.dataset))
                print('save success')
            np.save('/home/lidong/Documents/RSDEN/RSDEN//loss.npy', loss_rec)
        if epoch % 15 == 0:
            # Periodic snapshot; `error` is defined here because 15 is a
            # multiple of 3, so the testing branch above already ran.
            state = {
                'epoch': epoch + 1,
                'rsnet_state': rsnet.state_dict(),
                'drnet_state': drnet.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'error': error,
            }
            torch.save(
                state, "{}_{}_{}_model.pkl".format(args.arch, args.dataset,
                                                   str(epoch)))
            print('save success')