Ejemplo n.º 1
0
def main():
    """Restore a trained model from a checkpoint and run H5 inference.

    Parses the command-line arguments, assembles the transform pipeline,
    builds the dataset, data loader and model selected by those arguments,
    restores the model weights from ``args.model_checkpoint`` and finally
    calls ``infer_utterance_h5(model, loader, args)``.

    Raises:
        Exception: if ``args.out_h5data`` already exists as a file.
        ValueError: if ``args.predict`` or ``args.model`` is not one of the
            values handled below.
    """
    args = parser.parse_args()

    # Bail out early when the output file is already present.
    if os.path.isfile(args.out_h5data):
        raise Exception("Experiment name " + args.out_h5data + " already exists.")

    pipeline = []
    # TODO Encode body points also differentially to some joint not only hand wrt wrist
    if args.dif_encoding:
        pipeline.append(WristDifference())
        pipeline.append(ChestDifference())
    # TODO Change Normalization scheme to fixed bone dist
    if args.normalize:
        pipeline.append(NormalizeFixedFactor(1280))

    # The prediction target fixes the joint counts that are later passed to
    # TextPoseTransformer, and the transform that builds each item.
    if args.predict == "right_index":
        n_input = 12 + 17
        n_output = 4
        pipeline.append(BuildIndexItem())
    elif args.predict == "right_3fingers":
        n_input = 12 + 9
        n_output = 12
        pipeline.append(Build3fingerItem())
    elif args.predict == "right_hand":
        # NOTE: a commented-out alternative in the original used n_input = 12.
        n_input = 8
        n_output = 21
        pipeline.append(BuildRightHandItem())
    else:
        raise ValueError("Unsupported args.predict value: %r" % (args.predict,))

    pipeline = torchvision.transforms.Compose(pipeline)

    # Any model whose name contains "Text" reads from the text+pose H5 dataset.
    if "Text" in args.model:
        dataset = TextPoseH5Dataset(args.valid_h5data, args.valid_textdata,
                                    args.max_frames, pipeline,
                                    selection=args.frames_selection,
                                    use_rand_tokens=args.rand_tokens)
    else:
        dataset = FastPoseDataset(args.data, args.max_frames, pipeline)

    loader = DataLoader(dataset, batch_size=128, collate_fn=collate_function_h5)

    if args.model == "Conv":
        model = ConvModel(args.conv_channels, "ReLU", pos_emb=args.conv_pos_emb)
    elif args.model == "ConvTransformerEncoder":
        model = ConvTransformerEncoder(args, 21 * 2)
    elif args.model == "TransformerEnc":
        model = TransformerEnc(ninp=12 * 2, nhead=4, nhid=128, nout=21 * 2,
                               nlayers=4, dropout=args.transformer_dropout)
    elif args.model == "TextPoseTransformer":
        model = TextPoseTransformer(n_tokens=1000, n_joints=n_input, joints_dim=2,
                                    nhead=4, nhid=128, nout=n_output * 2,
                                    n_enc_layers=4, n_dec_layers=4,
                                    dropout=args.transformer_dropout)
    else:
        raise ValueError("Unsupported args.model value: %r" % (args.model,))

    # map_location="cpu" lets a checkpoint saved from a GPU be restored on a
    # host without CUDA; load_state_dict copies the values into the model's
    # existing parameters, so the model's own device is not changed by this.
    state_dict = torch.load(args.model_checkpoint, map_location="cpu")
    model.load_state_dict(state_dict)

    infer_utterance_h5(model, loader, args)
Ejemplo n.º 2
0
def main():
    """Restore a trained model from a checkpoint and run utterance inference.

    Parses the command-line arguments, creates ``args.output_folder`` and
    pickles the parsed arguments into ``args.pckl`` inside it, builds the
    dataset, data loader and model selected by those arguments, restores the
    model weights from ``args.model_checkpoint`` and finally calls
    ``infer_utterance(model, loader, args)``.

    Raises:
        Exception: if ``args.output_folder`` already exists as a directory.
        ValueError: if ``args.model`` is not one of the values handled below.
    """
    args = parser.parse_args()

    # Bail out early when the output directory is already present.
    if os.path.isdir(args.output_folder):
        raise Exception("Experiment name " + args.output_folder +
                        " already exists.")
    os.mkdir(args.output_folder)

    # Store the parsed arguments next to the results of this run.
    with open(os.path.join(args.output_folder, "args.pckl"), "wb") as f:
        pickle.dump(args, f)

    transform = NormalizeFixedFactor(1280) if args.normalize else None

    # Any model whose name contains "Text" reads from the text+pose dataset.
    if "Text" in args.model:
        dataset = FastTextPoseDataset(args.data,
                                      args.max_frames,
                                      transform,
                                      use_rand_tokens=args.rand_tokens)
    else:
        dataset = FastPoseDataset(args.data, args.max_frames, transform)

    loader = DataLoader(dataset, batch_size=1, collate_fn=collate_function)

    if args.model == "Conv":
        model = ConvModel(args.conv_channels,
                          activation="ReLU",
                          pos_emb=args.conv_pos_emb)
    elif args.model == "TransformerEncoder":
        model = TransformerEncoder(args, 100)
    elif args.model == "ConvTransformerEncoder":
        model = ConvTransformerEncoder(args, 21 * 2)
    elif args.model == "TransformerEnc":
        model = TransformerEnc(ninp=12 * 2,
                               nhead=4,
                               nhid=100,
                               nout=21 * 2,
                               nlayers=4,
                               dropout=0.0)
    elif args.model == "TextPoseTransformer":
        model = TextPoseTransformer(n_tokens=1000,
                                    n_joints=12,
                                    joints_dim=2,
                                    nhead=4,
                                    nhid=128,
                                    nout=21 * 2,
                                    n_enc_layers=4,
                                    n_dec_layers=4,
                                    dropout=args.transformer_dropout)
    else:
        raise ValueError("Unsupported args.model value: %r" % (args.model,))

    # map_location="cpu" lets a checkpoint saved from a GPU be restored on a
    # host without CUDA; load_state_dict copies the values into the model's
    # existing parameters, so the model's own device is not changed by this.
    state_dict = torch.load(args.model_checkpoint, map_location="cpu")
    model.load_state_dict(state_dict)

    infer_utterance(model, loader, args)
Ejemplo n.º 3
0
if __name__ == '__main__':

    args = parser.parse_args()
    print(args)
    if args.resume:
        assert (bool(args.exp))
        with open("%s/args.pckl" % args.exp, "rb") as f:
            args = pickle.load(f)
            args.resume = True
            pass

    transform = None
    # TODO Change Normalization scheme to fixed bone dist
    if args.normalize:
        transform = NormalizeFixedFactor(1280)

    train_dataset = FastPoseDataset(args.train_data, args.max_frames,
                                    transform)
    valid_dataset = FastPoseDataset(args.valid_data, args.max_frames,
                                    transform)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  collate_fn=collate_function)
    valid_dataloader = DataLoader(valid_dataset,
                                  batch_size=args.batch_size,
                                  collate_fn=collate_function)

    if args.model == "Conv":
        model = ConvModel(args.conv_channels,
Ejemplo n.º 4
0
        with open("%s/args.pckl" % args.exp, "rb") as f:
            args = pickle.load(f)
            args.resume = True
            # try:
            #     _ = args.rand_tokens
            # except:
            #     args.rand_tokens = True

    transforms = []
    # TODO Encode body points also differentially to some joint not only hand wrt wrist
    if args.dif_encoding:
        transforms.append(WristDifference())
        transforms.append(ChestDifference())
    # TODO Change Normalization scheme to fixed bone dist
    if args.normalize:
        transforms.append(NormalizeFixedFactor(1280))
    if args.predict == "right_index":
        n_input = 12 + 17
        n_output = 4
        transforms.append(BuildIndexItem())
    elif args.predict == "right_3fingers":
        n_input = 12 + 9
        n_output = 12
        transforms.append(Build3fingerItem())
    elif args.predict == "right_hand":
        n_input = 8
        n_output = 21
        transforms.append(BuildRightHandItem())
    else:
        raise ValueError()