Example #1
0
def test_test_real():
    """Evaluate pretrained feature/regressor checkpoints on the real-image dataset.

    Parses command-line arguments, builds a non-shuffled ``SepeDataset`` loader
    from ``args.poses_train`` / ``args.images_train`` (batch size 10, incomplete
    final batch dropped), loads the hard-coded checkpoints
    ``feature_ntsd_2_10.pt`` and ``regressor_seed_ntsd_2_10.pt`` into fresh
    ``DVOFeature`` / ``DVORegression`` modules, and hands everything to ``test``.
    """
    args = parse()
    print(args)
    dataset = SepeDataset(args.poses_train, args.images_train, coor_layer_flag=False)
    dataloader = DataLoader(dataset, batch_size=10, shuffle=False, num_workers=1, drop_last=True)
    dvo_feature_extractor = DVOFeature()
    dvo_regressor = DVORegression()
    # map_location='cpu' lets checkpoints saved from GPU tensors load on
    # CPU-only hosts; load_state_dict copies the values into the modules'
    # existing parameters, so where those parameters live is unchanged.
    # NOTE(review): the feature checkpoint lacks the '_seed' infix that the
    # regressor checkpoint has — confirm this pairing is intended.
    dvo_feature_extractor.load_state_dict(torch.load('feature_ntsd_2_10.pt', map_location='cpu'))
    dvo_regressor.load_state_dict(torch.load('regressor_seed_ntsd_2_10.pt', map_location='cpu'))
    test(dvo_feature_extractor, dvo_regressor, dataloader, args)
Example #2
0
def test_test():
    """Evaluate feature/regressor checkpoints on a tiny synthetic dataset.

    Parses command-line arguments, builds a ``RandomDataset`` of 2 samples for
    the space-separated integer axes in ``args.test_motion_ax``, loads the
    checkpoints named after ``args.motion_ax`` (spaces removed), e.g.
    ``feature<axes>.pt`` / ``regressor<axes>.pt``, and runs ``test``.
    """
    args = parse()
    print(args)
    # split() with no argument tolerates repeated/leading/trailing whitespace,
    # which split(' ') would turn into empty tokens and int('') failures.
    test_motion_ax_i = [int(i) for i in args.test_motion_ax.split()]
    dataset = RandomDataset(2, motion_ax=test_motion_ax_i)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, drop_last=True)
    dvo_feature_extractor = DVOFeature()
    dvo_regressor = DVORegression()
    motion_tag = args.motion_ax.replace(' ', '')
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
    # load_state_dict copies into the modules' existing parameters anyway.
    # NOTE(review): test_train saves 'feature_seed<axes><epoch>.pt', which is
    # not the name loaded here — confirm where these checkpoints come from.
    dvo_feature_extractor.load_state_dict(torch.load('feature' + motion_tag + '.pt', map_location='cpu'))
    dvo_regressor.load_state_dict(torch.load('regressor' + motion_tag + '.pt', map_location='cpu'))
    test(dvo_feature_extractor, dvo_regressor, dataloader, args)
Example #3
0
def test_adapt():
    """Run ``adapt`` from a source dataset to a target dataset and save the results.

    Builds shuffled ``SepeDataset`` loaders for the source
    (``args.poses_train`` / ``args.images_train``) and target
    (``args.poses_target`` / ``args.images_target``) data. Both a source and a
    target ``DVOFeature`` are initialised from ``args.feature_model``, and a
    fresh ``Discriminator`` is built. After ``adapt`` they are saved as
    ``tgt_feature_<tag><epoch>.pt`` and ``dis_<tag><epoch>.pt``.
    """
    args = parse()
    print(args)

    def seed_worker(worker_id):
        # torch.initial_seed() can be far larger than 2**32 - 1; np.uint32() of
        # such a Python int raises OverflowError on NumPy >= 2 (and warns on
        # 1.24+). Reducing modulo 2**32 gives the value older NumPy wrapped to.
        np.random.seed((torch.initial_seed() + worker_id) % 2**32)

    dataset = SepeDataset(args.poses_train, args.images_train, coor_layer_flag=False)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1, drop_last=True,
                            worker_init_fn=seed_worker)
    dataset_tgt = SepeDataset(args.poses_target, args.images_target, coor_layer_flag=False)
    dataloader_tgt = DataLoader(dataset_tgt, batch_size=1, shuffle=True, num_workers=1, drop_last=True,
                                worker_init_fn=seed_worker)
    src_extractor = DVOFeature()
    tgt_extractor = DVOFeature()
    # Both extractors start from the same weights. map_location='cpu' lets
    # GPU-saved checkpoints load on CPU-only hosts.
    src_extractor.load_state_dict(torch.load(args.feature_model, map_location='cpu'))
    tgt_extractor.load_state_dict(torch.load(args.feature_model, map_location='cpu'))
    dvo_discriminator = Discriminator(500, 500, 2)
    adapt(src_extractor, tgt_extractor, dvo_discriminator, dataloader, dataloader_tgt, args)
    torch.save(tgt_extractor.state_dict(), 'tgt_feature_' + args.tag + str(args.epoch) + '.pt')
    torch.save(dvo_discriminator.state_dict(), 'dis_' + args.tag + str(args.epoch) + '.pt')
Example #4
0
def test_model(image):
    """Smoke-test a single forward pass through freshly built networks.

    ``image`` is fed to a new ``DVOFeature``. Its output goes to a new
    ``DVORegression`` and a new ``Discriminator(500, 500, 2)``. The feature
    shape, the regressor output shape and the raw discriminator output are
    printed in that order. Nothing is returned.
    """
    # Tuple elements are evaluated left to right, so the modules are created
    # in the same order as separate statements would create them.
    extractor, regressor, critic = DVOFeature(), DVORegression(), Discriminator(500, 500, 2)

    features = extractor(image)
    print(features.shape)
    print(regressor(features).shape)
    print(critic(features))
Example #5
0
def test_train_real():
    """Train the feature extractor and regressor on the real-image dataset.

    Builds a shuffled ``SepeDataset`` loader (batch size 3, incomplete final
    batch dropped) from ``args.poses_train`` / ``args.images_train``. It trains
    fresh ``DVOFeature`` / ``DVORegression`` modules via ``train`` and saves
    them as ``feature_<tag><epoch>.pt`` and ``regressor_<tag><epoch>.pt``.
    """
    args = parse()
    print(args)

    def seed_worker(worker_id):
        # torch.initial_seed() can exceed 2**32 - 1, and np.uint32() of such an
        # int raises OverflowError on NumPy >= 2. Reduce modulo 2**32 instead.
        np.random.seed((torch.initial_seed() + worker_id) % 2**32)

    dataset = SepeDataset(args.poses_train, args.images_train, coor_layer_flag=False)
    dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=1, drop_last=True,
                            worker_init_fn=seed_worker)
    dvo_feature_extractor = DVOFeature()
    dvo_regressor = DVORegression()
    # NOTE(review): unused. Kept because its construction presumably draws from
    # the global torch RNG (weight init), so removing it would change the
    # shuffle order / training randomness of any seeded run. Confirm before deleting.
    dvo_discriminator = Discriminator(500, 500, 2)
    trained_feature, trained_regressor = train(dvo_feature_extractor, dvo_regressor, dataloader, args)
    torch.save(trained_feature.state_dict(), 'feature_' + args.tag + str(args.epoch) + '.pt')
    torch.save(trained_regressor.state_dict(), 'regressor_' + args.tag + str(args.epoch) + '.pt')
Example #6
0
def test_train():
    """Train the feature extractor and regressor on a synthetic dataset.

    Parses the space-separated integer axes in ``args.motion_ax`` and builds a
    20000-sample ``RandomDataset`` with them. The loader is non-shuffled, with
    batch size 1000 and any incomplete final batch dropped. It trains fresh
    ``DVOFeature`` / ``DVORegression`` modules via ``train`` and saves them as
    ``feature_seed<axes><epoch>.pt`` / ``regressor_seed<axes><epoch>.pt``,
    where ``<axes>`` is ``args.motion_ax`` with spaces removed.
    """
    args = parse()
    print(args)

    def seed_worker(worker_id):
        # torch.initial_seed() can exceed 2**32 - 1, and np.uint32() of such an
        # int raises OverflowError on NumPy >= 2. Reduce modulo 2**32 instead.
        np.random.seed((torch.initial_seed() + worker_id) % 2**32)

    # split() with no argument tolerates repeated/stray whitespace, which
    # split(' ') would turn into int('') failures.
    motion_ax_i = [int(i) for i in args.motion_ax.split()]
    dataset = RandomDataset(20000, motion_ax=motion_ax_i)
    dataloader = DataLoader(dataset, batch_size=1000, shuffle=False, num_workers=1, drop_last=True,
                            worker_init_fn=seed_worker)
    dvo_feature_extractor = DVOFeature()
    dvo_regressor = DVORegression()
    # NOTE(review): unused. Kept because its construction presumably draws from
    # the global torch RNG (weight init), so removing it would change the
    # randomness of any seeded training run. Confirm before deleting.
    dvo_discriminator = Discriminator(500, 500, 2)
    trained_feature, trained_regressor = train(dvo_feature_extractor, dvo_regressor, dataloader, args)
    motion_tag = args.motion_ax.replace(' ', '')
    torch.save(trained_feature.state_dict(), 'feature_seed' + motion_tag + str(args.epoch) + '.pt')
    torch.save(trained_regressor.state_dict(), 'regressor_seed' + motion_tag + str(args.epoch) + '.pt')