Example #1
import torch
import torchvision.transforms as transforms

# PoseEstimator and the opt namespace (parsed command-line options) are assumed
# to be provided by the surrounding project.

# ================CREATE NETWORK============================ #
model = PoseEstimator(shape=opt.shape,
                      shape_feature_dim=opt.shape_feature_dim,
                      img_feature_dim=opt.img_feature_dim,
                      azi_classes=opt.azi_classes,
                      ele_classes=opt.ele_classes,
                      inp_classes=opt.inp_classes,
                      render_number=opt.num_render,
                      separate_branch=opt.separate_branch,
                      channels=opt.channels)
model.cuda()
if opt.model is not None:
    checkpoint = torch.load(opt.model,
                            map_location=lambda storage, loc: storage.cuda())
    pretrained_dict = checkpoint['state_dict']
    model_dict = model.state_dict()
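    # keep only the pretrained entries whose names exist in the current model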
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    print('Previous weight loaded')
model.eval()
# ========================================================== #
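
# Optional sanity check (not part of the original snippet): count the trainable
# parameters of the PoseEstimator built above, e.g. to confirm the loaded
# checkpoint matches the expected architecture size.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('PoseEstimator trainable parameters: %d' % n_params)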

# ==================INPUT IMAGE AND RENDERS================= #
# define data preprocessing for real images in validation
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
data_validating = transforms.Compose([transforms.ToTensor(), normalize])
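
# Illustrative usage of the preprocessing above (not in the original snippet;
# 'example.jpg' is a placeholder path): load a real image, apply the validation
# transform and add a batch dimension before feeding it to the network.
from PIL import Image
im = Image.open('example.jpg').convert('RGB')
im_tensor = data_validating(im).unsqueeze(0).cuda()  # shape: (1, 3, H, W)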
Example #2
import os
import torch

# train(), train_loader, lrScheduler, the criterion_* losses, optimizer, opt,
# result_path and logname are assumed to be defined earlier in the project.

# ==============BEGINNING OF THE LEARNING LOOP=============== #
# initialization
best_acc = 0.
losses = []  # per-epoch training losses, stored inside the checkpoint below

for epoch in range(opt.n_epoch):
    # train for one epoch
    train_loss, train_acc_rot = train(train_loader, model, opt.bin_size,
                                      criterion_azi, criterion_ele, criterion_inp, criterion_reg, optimizer)
    losses.append(train_loss)
    best_acc = max(best_acc, train_acc_rot)  # track the best rotation accuracy seen so far

    # update learning rate (after the optimizer steps taken inside train())
    lrScheduler.step()

    # save checkpoint
    torch.save({'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'losses': losses},
               os.path.join(result_path, 'checkpoint.pth'))

    # append this epoch's loss and accuracy to the log file
    with open(logname, 'a') as f:
        f.write('Epoch: %03d || train_loss %.2f || train_acc %.2f\n' % (epoch, train_loss, train_acc_rot))
# ========================================================== #
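
# Sketch of resuming from the checkpoint saved above (not in the original
# snippet): it assumes the same model, optimizer and result_path objects are
# in scope and that training should continue from the stored epoch.
resume_path = os.path.join(result_path, 'checkpoint.pth')
if os.path.isfile(resume_path):
    state = torch.load(resume_path, map_location=lambda storage, loc: storage.cuda())
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
    start_epoch = state['epoch'] + 1
    losses = state['losses']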


import pickle
import collections

from torch.utils.data import DataLoader

# dataset_train is assumed to be built earlier by the surrounding project
data_loader = DataLoader(dataset_train, batch_size=1, shuffle=False, num_workers=1)
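
# Illustrative continuation (not from the original example): run a single pass
# over the loader and pickle the raw samples keyed by index. The per-sample
# structure depends entirely on dataset_train; 'samples.pkl' is a placeholder.
samples = collections.OrderedDict()
for idx, sample in enumerate(data_loader):
    samples[idx] = sample
with open('samples.pkl', 'wb') as f:
    pickle.dump(samples, f)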