Example #1
import os

import torch


def load_model(model):
    """Build the pose network and load the checkpoint whose filename index is `model`."""
    # get_pose_net is provided elsewhere in this project.
    net = get_pose_net()
    if torch.cuda.is_available():
        net = net.cuda()
        # net = nn.DataParallel(net)  # multi-GPU

    save_path = os.path.join('ckpt2', 'ucihand_lstm_pm' + str(model) + '.pth')
    state_dict = torch.load(save_path)
    net.load_state_dict(state_dict)
    return net
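
# A minimal usage sketch (not from the original code): load a checkpoint and run
# one dummy forward pass. The checkpoint index 50 and the single-channel
# 224x224 input shape are assumptions, mirroring the forward pass in Example #3.
net = load_model(50)  # assumes ckpt2/ucihand_lstm_pm50.pth exists
net.eval()
dummy = torch.randn(1, 1, 224, 224)
if torch.cuda.is_available():
    dummy = dummy.cuda()
with torch.no_grad():
    out = net(dummy)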
Example #2
import os

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

# Dhp19PoseDataset, get_pose_net, JointsMSELoss, get_optimizer, AverageMeter and
# config are provided elsewhere in this project; path and hyper-parameter
# variables (save_dir, batch_size, epochs, ...) are assumed to be set earlier.

if not os.path.exists(save_dir):
    os.mkdir(save_dir)

transform = transforms.Compose([transforms.ToTensor()])

# Build dataset
train_data = Dhp19PoseDataset(data_dir=train_data_dir,
                              label_dir=train_label_dir,
                              temporal=temporal,
                              train=True)
print('Train dataset total number of image sequences: ' +
      str(len(train_data)))

# Data loader (shuffle=False keeps the image sequences in their stored order)
train_dataset = DataLoader(train_data, batch_size=batch_size, shuffle=False)

net = get_pose_net()
#gpus = [int(i) for i in '2'.split(',')]
model = torch.nn.DataParallel(net, device_ids=[0]).cuda()


def train():
    criterion = JointsMSELoss(
        use_target_weight=config.LOSS.USE_TARGET_WEIGHT).cuda()
    optimizer = get_optimizer(config, model)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR)
    for epoch in range(begin_epoch, epochs + 1):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        acc = AverageMeter()
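
        # Sketch only (not in the original snippet): a plausible continuation of
        # the inner training loop. It assumes the loader yields
        # (input, target, target_weight) triples and that JointsMSELoss follows
        # the HRNet-style call criterion(output, target, target_weight); both
        # are assumptions, not confirmed by the code above.
        model.train()
        for input, target, target_weight in train_dataset:
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            target_weight = target_weight.cuda(non_blocking=True)

            output = model(input)                            # forward pass
            loss = criterion(output, target, target_weight)  # heatmap MSE loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.update(loss.item(), input.size(0))  # running loss meter
            # (accuracy meter update omitted in this sketch)
        lr_scheduler.step()  # one MultiStepLR step per epoch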
Example #3
# vgg11 = models.vgg11().cuda()
# summary(vgg11,(3,224,224))

# vgg16 = models.vgg16().cuda()
# summary(vgg16,(3,224,224))

# vgg19 = models.vgg19().cuda()
# summary(vgg19,(3,224,224))

# alexnet = models.alexnet().cuda()
# summary(alexnet,(3,224,224))

# resnet18 = models.resnet18().cuda()
# summary(resnet18,(3,224,224))

# resnet34= models.resnet34().cuda()
# summary(resnet34,(3,224,224))

# from torch import nn
# lstm = nn.LSTM(4, 10).cuda()
# print(lstm)

# net = get_pose_net().cuda()
# summary(net,(1, 368, 368))
# print(net)
import torch

net = get_pose_net().cuda()
# a = torch.randn(1, 3, 224, 224).cuda()  # three-channel input variant
a = torch.randn(1, 1, 224, 224).cuda()  # batch size = 1, single-channel 224x224 input
net(a)
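
# Sketch only: the commented summary(...) calls above use the torchsummary
# package; this reproduces that usage for the pose network (assumes
# `pip install torchsummary`; get_pose_net comes from this project).
from torchsummary import summary

summary(net, (1, 224, 224))  # same single-channel input size as the forward pass above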