示例#1
0
def test(root_dir, weight_path):
    """Evaluate a trained ``Net`` on the test split and print per-batch errors.

    Args:
        root_dir: Project directory, passed to ``HairNetDataset`` as ``project_dir``.
        weight_path: Path to a state dict loadable with ``torch.load`` and
            ``Net.load_state_dict``.

    For each batch of the test split (``train_flag=0``, ``noise_flag=0``) this
    prints the values returned by ``MyPosEvaluation`` and ``MyCurEvaluation``
    for the network output against ``convdata``. The network, both metrics
    and every batch are moved to CUDA, so a GPU is required.
    """
    print('This is the programme of testing.')
    BATCH_SIZE = 32
    # load model
    print('Building Network...')
    net = Net()
    net.cuda()
    pos_error = MyPosEvaluation()
    pos_error.cuda()
    cur_error = MyCurEvaluation()
    cur_error.cuda()
    print('Loading Network...')
    net.load_state_dict(torch.load(weight_path))
    net.eval()
    print('Loading Dataset...')
    test_data = HairNetDataset(project_dir=root_dir, train_flag=0, noise_flag=0)
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE)
    num_samples = len(test_data)
    print('Testing...')
    # Pure inference: disabling autograd avoids building (and holding GPU
    # memory for) a backward graph on every batch.
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):
            # The third item (visibility weights) is not used by either metric.
            img, convdata, _ = data
            img = img.cuda()
            convdata = convdata.cuda()
            output = net(img)
            pos = pos_error(output, convdata)
            cur = cur_error(output, convdata)
            # The final batch may be partial; never report more than the dataset size.
            seen = min(BATCH_SIZE * (i + 1), num_samples)
            print(str(seen) + '/' + str(num_samples) + ', Position loss: ' +
                  str(pos.item()) + ', Curvature loss: ' + str(cur.item()))
示例#2
0
文件: model.py 项目: givenone/HairNet
def demo(root_dir, weight_path):
    """Run ``ReconNet`` on the first test sample and dump its output to ``./demo``.

    Args:
        root_dir: Project directory, passed to ``HairNetDataset`` as ``project_dir``.
        weight_path: Path to a ``ReconNet`` state dict loadable with ``torch.load``.

    Side effects:
        * prints a ``summary`` of ``Net`` for a (3, 128, 128) input;
        * shows the first test image in an OpenCV window and blocks until a
          key is pressed;
        * writes the network output for that sample to ``demo/demo.convdata``
          (``np.save`` format), and its entries 0..2 along axis 1, rearranged
          into rows of 3 values, to ``demo/demo.txt`` (``np.savetxt`` format).

    Only the first sample of the test split is processed. Requires CUDA.
    """
    import os

    print('This is the programme of demo.')
    BATCH_SIZE = 1
    # load model
    print('Building Network...')
    # Printed for inspection only; this Net instance is not used afterwards.
    summary(Net().cuda(), (3, 128, 128))
    recon_net = ReconNet(3)
    recon_net.cuda()
    recon_net.load_state_dict(torch.load(weight_path))
    recon_net.eval()
    test_data = HairNetDataset(project_dir=root_dir,
                               train_flag=0,
                               noise_flag=0)
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE)
    import cv2
    import numpy as np
    # open() below does not create directories; make sure ./demo exists.
    os.makedirs('demo', exist_ok=True)
    # Only the first batch (a single sample, since BATCH_SIZE == 1) is used.
    for data in test_loader:
        img, _, _ = data

        cv2.imshow('', img[0].numpy().T)  # input orientation

        img = img.cuda()
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            output = recon_net(img)
        strands = output[0].cpu().detach().numpy()  # hair strands
        with open('demo/demo.convdata', 'wb') as wf:
            np.save(wf, strands)
        # Take entries 0..2 of axis 1, move that axis to the front, flatten
        # the remaining axes, then transpose so each row holds 3 values.
        positions = np.swapaxes(strands[:, :3, :, :], 0, 1)
        print(positions.shape)
        hair_pos = np.swapaxes(positions.reshape(3, -1), 0, 1)
        print(hair_pos.shape)
        with open('demo/demo.txt', 'w') as wf:
            np.savetxt(wf, hair_pos)

        cv2.waitKey(0)
        cv2.destroyAllWindows()
        break
示例#3
0
def train(root_dir):
    """Train ``Net`` from scratch on the training split.

    Args:
        root_dir: Project directory. It is passed to ``HairNetDataset`` as
            ``project_dir`` and is the base directory for every output file.

    Side effects (all paths under ``root_dir``):
        * ``log.txt``: a progress line is appended every LOG_STEP batches;
        * ``weight/NNNNNN_weight.pt``: the state dict is saved every
          WEIGHT_STEP epochs (the directory is created if missing);
        * ``loss.png``: a plot of the per-epoch summed loss.

    The learning rate starts at LR and is halved at the start of every
    LR_STEP-th epoch. Requires CUDA.
    """
    print('This is the programme of training.')
    # build model
    print('Initializing Network...')
    net = Net()
    net.cuda()
    loss = MyLoss()
    loss.cuda()
    # set hyperparameter
    EPOCH = 100
    BATCH_SIZE = 32
    LR = 1e-4
    # set parameter of log
    PRINT_STEP = 10  # batch
    LOG_STEP = 100  # batch
    WEIGHT_STEP = 5  # epoch
    LR_STEP = 10  # epochs between learning-rate halvings
    log_path = root_dir + '/log.txt'
    weight_dir = root_dir + '/weight/'
    # The separator was previously missing (root_dir + 'loss.png'), which put
    # the plot beside root_dir rather than inside it.
    loss_pic_path = root_dir + '/loss.png'
    # torch.save does not create missing directories.
    os.makedirs(weight_dir, exist_ok=True)
    # load data
    print('Setting Dataset and DataLoader...')
    train_data = HairNetDataset(project_dir=root_dir, train_flag=1, noise_flag=1)
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE)
    num_samples = len(train_data)
    # set optimizer
    optimizer = optim.Adam(net.parameters(), lr=LR)
    loss_list = []
    print('Training...')
    for i in range(EPOCH):
        epoch_loss = 0.0
        # Halve the learning rate at the start of every LR_STEP-th epoch.
        if (i + 1) % LR_STEP == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.5
        for j, data in enumerate(train_loader, 0):
            img, convdata, visweight = data
            img = img.cuda()
            convdata = convdata.cuda()
            visweight = visweight.cuda()
            # Shapes as noted by the original author (not checked here):
            # img (batch_size, 3, 256, 256)
            # convdata (batch_size, 100, 4, 32, 32)
            # visweight (batch_size, 100, 32, 32)

            # zero the parameter gradients
            optimizer.zero_grad()
            output = net(img)  # presumably (batch_size, 100, 4, 32, 32) like convdata
            my_loss = loss(output, convdata, visweight)
            batch_loss = my_loss.item()
            epoch_loss += batch_loss
            my_loss.backward()
            optimizer.step()
            # The final batch may be partial; never report more than the dataset size.
            seen = min(BATCH_SIZE * (j + 1), num_samples)
            progress = 'epoch: {}, {}/{}, loss: {}'.format(i + 1, seen, num_samples, batch_loss)
            if (j + 1) % PRINT_STEP == 0:
                print(progress)
            if (j + 1) % LOG_STEP == 0:
                # Append mode creates the file when it does not exist yet.
                with open(log_path, 'a') as f:
                    f.write(progress + '\n')
        if (i + 1) % WEIGHT_STEP == 0:
            save_path = weight_dir + str(i + 1).zfill(6) + '_weight.pt'
            torch.save(net.state_dict(), save_path)
        loss_list.append(epoch_loss)
    print('Finish...')
    # Draw on a fresh figure so nothing left on the current pyplot figure leaks in.
    fig = plt.figure()
    plt.plot(loss_list)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.savefig(loss_pic_path)
    plt.close(fig)
示例#4
0
文件: model.py 项目: givenone/HairNet
def train(root_dir, load_epoch=None):
    """Train ``Net``, optionally resuming from a checkpoint saved by this function.

    Args:
        root_dir: Project directory. It is passed to ``HairNetDataset`` as
            ``project_dir`` and is the base directory for every output file.
        load_epoch: Epoch number (an int, or a numeric string such as
            ``'000050'``) of a checkpoint at
            ``<root_dir>/weight/NNNNNN_weight.pt``. When given, those weights
            are loaded, the learning rate is set to what the schedule had
            reached, and training continues at epoch ``load_epoch + 1``.
            Optimizer (Adam) moment state is not saved, so it restarts from
            zero.

    Side effects (all paths under ``root_dir``; directories created if missing):
        * ``log.txt``: a progress line is appended every LOG_STEP batches;
        * ``weight/NNNNNN_weight.pt``: the state dict is saved every
          WEIGHT_STEP epochs;
        * ``loss.png``: a plot of the per-epoch summed loss against the real
          epoch index;
        * ``debug/log.txt``, ``debug/weight.pt``, ``debug/loss.png``: written
          once, on the first batch of a from-scratch run, to smoke-test every
          output path.

    Requires CUDA.
    """
    print('This is the programme of training.')

    log_path = root_dir + '/log.txt'
    loss_pic_path = root_dir + '/loss.png'
    weight_save_path = root_dir + '/weight/'
    debug_weight_save_path = root_dir + '/debug/'
    debug_log_path = root_dir + '/debug/log.txt'
    debug_loss_pic_path = root_dir + '/debug/loss.png'
    # torch.save / open / savefig do not create missing directories.
    os.makedirs(weight_save_path, exist_ok=True)
    os.makedirs(debug_weight_save_path, exist_ok=True)

    # build model
    print('Initializing Network...')
    net = Net()
    net.cuda()
    loss = MyLoss()
    loss.cuda()
    # set hyperparameter
    EPOCH = 500
    BATCH_SIZE = 100
    LR = 1e-4
    # set parameter of log
    PRINT_STEP = 10  # batch
    LOG_STEP = 100  # batch
    WEIGHT_STEP = 1  # epoch
    LR_STEP = 250  # epochs between learning-rate halvings
    # load weight if possible
    start_epoch = 0
    if load_epoch is not None:
        start_epoch = int(load_epoch)
        # Read from the same directory and zero-padded name the save below uses,
        # rather than a cwd-relative 'weight/' path with an unpadded number.
        weight_load_path = weight_save_path + str(start_epoch).zfill(6) + '_weight.pt'
        net.load_state_dict(torch.load(weight_load_path))
        print("start epoch:", start_epoch + 1)
    # Restore the schedule position: the loop below halves the LR at the start
    # of each epoch i with (i + 1) % LR_STEP == 0 (i != 0), so start_epoch // LR_STEP
    # halvings (for LR_STEP > 1) already happened before a resumed run.
    initial_lr = LR * 0.5 ** (start_epoch // LR_STEP)
    # load data
    print('Setting Dataset and DataLoader...')
    train_data = HairNetDataset(project_dir=root_dir,
                                train_flag=1,
                                noise_flag=1)
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE)
    num_samples = len(train_data)
    # set optimizer
    optimizer = optim.Adam(net.parameters(), lr=initial_lr)
    loss_list = []
    print('Training...')
    for i in range(start_epoch, EPOCH):
        epoch_loss = 0.0
        # change learning rate
        if (i + 1) % LR_STEP == 0 and i != 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.5
        for j, data in enumerate(train_loader, 0):
            img, convdata, visweight = data
            img = img.cuda()
            convdata = convdata.cuda()
            visweight = visweight.cuda()
            # Shapes as noted by the original author (not checked here):
            # img (batch_size, 3, 128, 128)
            # convdata (batch_size, 100, 4, 32, 32)
            # visweight (batch_size, 100, 32, 32)

            # zero the parameter gradients
            optimizer.zero_grad()
            output = net(img)  # presumably (batch_size, 100, 4, 32, 32) like convdata
            my_loss = loss(output, convdata, visweight)
            batch_loss = my_loss.item()
            # The final batch may be partial; never report more than the dataset size.
            seen = min(BATCH_SIZE * (j + 1), num_samples)
            progress = 'epoch: {}, {}/{}, loss: {}'.format(i + 1, seen, num_samples, batch_loss)

            # debug: on the very first batch of a from-scratch run, exercise
            # every output path once so failures surface immediately.
            if i == 0 and j == 0:
                # Append mode creates the file when it does not exist yet.
                with open(debug_log_path, 'a') as f:
                    f.write(progress + '\n')
                    print('Debug of writing log.txt!')
                save_path = debug_weight_save_path + 'weight.pt'
                torch.save(net.state_dict(), save_path)
                print('Debug of saving model!')
                # Use a dedicated figure and close it, so this debug point is not
                # drawn onto the final loss plot (pyplot draws on the current figure).
                debug_fig = plt.figure()
                plt.plot([batch_loss])
                plt.xlabel('Epoch')
                plt.ylabel('Loss')
                plt.xlim(0, EPOCH - 1)
                plt.savefig(debug_loss_pic_path)
                plt.close(debug_fig)
                print('Debug of drawing loss picture!')

            epoch_loss += batch_loss
            my_loss.backward()
            optimizer.step()
            if (j + 1) % PRINT_STEP == 0:
                print(progress)
            if (j + 1) % LOG_STEP == 0:
                with open(log_path, 'a') as f:
                    f.write(progress + '\n')
        if (i + 1) % WEIGHT_STEP == 0:
            save_path = weight_save_path + str(i + 1).zfill(6) + '_weight.pt'
            torch.save(net.state_dict(), save_path)
        loss_list.append(epoch_loss)
    print('Finish...')
    # Fresh figure; x values are the real epoch indices so a resumed run lines
    # up with xlim(0, EPOCH - 1) instead of being shifted to start at 0.
    fig = plt.figure()
    plt.plot(range(start_epoch, start_epoch + len(loss_list)), loss_list)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.xlim(0, EPOCH - 1)
    plt.savefig(loss_pic_path)
    plt.close(fig)