Exemplo n.º 1
0
def load_ckp(ckp_path, model, map_location="cpu"):
  """Load model weights from a checkpoint file into ``model`` in place.

  The checkpoint must be a dict whose ``'state_dict'`` entry holds the
  parameters to load.

  Args:
    ckp_path: path of the checkpoint file.
    model: ``nn.Module`` whose parameters/buffers are overwritten.
    map_location: passed to ``torch.load``. Defaults to the CPU so that a
      checkpoint saved on a GPU can also be opened on a machine without
      CUDA; ``load_state_dict`` copies the values into the model's existing
      tensors on whatever device they already live on.
  """
  state = torch.load(ckp_path, map_location=map_location)
  model.load_state_dict(state['state_dict'])
  print("model load from %s" % ckp_path)

if __name__ == "__main__":
  # Evaluation entry point: builds a PointNet, optionally restores weights
  # from ckp_path, and iterates the dataset under torch.no_grad().
  # SEED, gpus, batch_size and ckp_path are module-level names defined
  # outside this excerpt.
  torch.manual_seed(SEED)
  # Use the first listed GPU when CUDA is available, otherwise the CPU.
  device = torch.device(f'cuda:{gpus[0]}' if torch.cuda.is_available() else 'cpu')
  print("Loading test dataset...")
  # NOTE(review): the message says "test dataset" but train=1 is passed;
  # confirm the meaning of PointNetDataset's `train` flag.
  test_data = PointNetDataset("./dataset/modelnet40_normal_resampled", train=1)
  test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
  model = PointNet().to(device=device)
  if ckp_path:
    load_ckp(ckp_path, model)
    model = model.to(device)
  
  # Switch layers such as dropout/batch-norm to inference behaviour.
  model.eval()

  with torch.no_grad():
    # Accumulators for the (not shown) accuracy computation.
    accs = []
    gt_ys = []
    pred_ys = []
    for x, y in test_loader:
      x = x.to(device)
      y = y.to(device)

      # NOTE(review): exercise stub — `out =` below is a syntax error until
      # this TODO is filled in, and the code after the last TODO is not shown.
      # TODO: put x into network and get out
      out =

      # TODO: get pred_y from out
Exemplo n.º 2
0
def train(args, io):
    """Train a ModelNet40 point-cloud classifier and checkpoint the best model.

    The network is chosen by ``args.model`` ('pointnet', 'dgcnn', 'ssg',
    'msg', 'ognet', 'ognet-small'), wrapped in ``nn.DataParallel`` and
    optimised with SGD (``args.use_sgd``) or Adam under a cosine-annealed
    learning rate. During the first ``args.warm_epoch`` epochs the loss is
    multiplied by a factor that rises linearly from 0.1 towards 1.0. After
    every epoch the model is evaluated on the test split; whenever
    ``test acc + balanced test acc`` is not worse than the best so far, its
    state dict is written to ``checkpoints/<args.exp_name>/models/model.t7``.

    Args:
        args: parsed command-line options (data, model, optimiser and
            architecture settings read below).
        io: logger exposing ``cprint(str)``.

    Raises:
        Exception: if ``args.model`` names an unsupported architecture.
    """
    train_loader = DataLoader(ModelNet40(partition='train',
                                         num_points=args.num_points),
                              num_workers=8,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(ModelNet40(partition='test',
                                        num_points=args.num_points),
                             num_workers=8,
                             batch_size=args.test_batch_size,
                             shuffle=True,
                             drop_last=False)

    device = torch.device("cuda" if args.cuda else "cpu")

    # Build the requested model.
    if args.model == 'pointnet':
        model = PointNet(args).to(device)
    elif args.model == 'dgcnn':
        model = DGCNN(args).to(device)
    elif args.model == 'ssg':
        model = PointNet2SSG(output_classes=40, dropout_prob=args.dropout)
        model.to(device)
    elif args.model == 'msg':
        model = PointNet2MSG(output_classes=40, dropout_prob=args.dropout)
        model.to(device)
    elif args.model in ('ognet', 'ognet-small'):
        # Widths noted by the original config: ognet [64,128,256,512],
        # ognet-small [48,96,192,384]; both take them from args.feature_dims.
        ognet_kwargs = dict(output_classes=40,
                            init_points=768,
                            input_dims=3,
                            dropout_prob=args.dropout,
                            id_skip=args.id_skip,
                            drop_connect_rate=args.drop_connect_rate,
                            cluster='xyzrgb',
                            pre_act=args.pre_act,
                            norm=args.norm_layer)
        if args.model == 'ognet' and args.efficient:
            # Build only the efficient variant rather than constructing a
            # Model_dense first and immediately discarding it.
            model = ModelE_dense(20, args.feature_dims, [512],
                                 gem=args.gem,
                                 ASPP=args.ASPP,
                                 **ognet_kwargs)
        else:
            model = Model_dense(20, args.feature_dims, [512], **ognet_kwargs)
        model.to(device)
    else:
        raise Exception("Not implemented")
    print(str(model))

    model = nn.DataParallel(model)
    print("Let's use", torch.cuda.device_count(), "GPUs!")

    # These architectures receive the batch tensor twice, in its loaded
    # layout; the others receive it with the last two axes swapped.
    takes_input_twice = args.model in ('ognet', 'ognet-small', 'ssg', 'msg')

    def forward(data):
        """Run the model on one batch using its input convention."""
        if takes_input_twice:
            return model(data, data)
        return model(data.permute(0, 2, 1))

    if args.use_sgd:
        print("Use SGD")
        opt = optim.SGD(model.parameters(),
                        lr=args.lr * 100,
                        momentum=args.momentum,
                        weight_decay=1e-4)
        scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr)
    else:
        print("Use Adam")
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
        scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=0.01 * args.lr)

    criterion = cal_loss

    best_test_acc = 0
    best_avg_per_class_acc = 0

    # Loss scale for warm-up: starts at 0.1 and ramps to 1.0 over the
    # first args.warm_epoch epochs.
    warm_up = 0.1
    # Reuse the loader's dataset instead of reading the training split
    # from disk a second time.
    warm_iteration = round(
        len(train_loader.dataset) / args.batch_size) * args.warm_epoch
    for epoch in range(args.epochs):
        ####################
        # Train
        ####################
        train_loss = 0.0
        count = 0.0
        model.train()
        train_pred = []
        train_true = []
        for data, label in train_loader:
            # reshape(-1) instead of squeeze(): a batch of one still yields a
            # 1-D label vector (squeeze would give a 0-d tensor, which
            # np.concatenate rejects).
            data, label = data.to(device), label.to(device).reshape(-1)
            batch_size = data.size()[0]
            opt.zero_grad()
            logits = forward(data)
            loss = criterion(logits, label)
            if epoch < args.warm_epoch:
                warm_up = min(1.0, warm_up + 0.9 / warm_iteration)
                loss = loss * warm_up
            loss.backward()
            opt.step()
            preds = logits.max(dim=1)[1]
            count += batch_size
            train_loss += loss.item() * batch_size
            train_true.append(label.cpu().numpy())
            train_pred.append(preds.detach().cpu().numpy())
        # Advance the cosine schedule once per epoch, after the optimizer
        # updates (the order required since PyTorch 1.1); stepping it at the
        # start of the epoch skipped the initial learning rate.
        scheduler.step()
        train_true = np.concatenate(train_true)
        train_pred = np.concatenate(train_pred)
        outstr = 'Train %d, loss: %.6f, train acc: %.6f, train avg acc: %.6f' % (
            epoch, train_loss * 1.0 / count,
            metrics.accuracy_score(train_true, train_pred),
            metrics.balanced_accuracy_score(train_true, train_pred))
        io.cprint(outstr)

        ####################
        # Test
        ####################
        test_loss = 0.0
        count = 0.0
        model.eval()
        test_pred = []
        test_true = []
        # Evaluation needs no autograd graph; this saves activation memory.
        with torch.no_grad():
            for data, label in test_loader:
                data, label = data.to(device), label.to(device).reshape(-1)
                batch_size = data.size()[0]
                logits = forward(data)
                loss = criterion(logits, label)
                preds = logits.max(dim=1)[1]
                count += batch_size
                test_loss += loss.item() * batch_size
                test_true.append(label.cpu().numpy())
                test_pred.append(preds.cpu().numpy())
        test_true = np.concatenate(test_true)
        test_pred = np.concatenate(test_pred)
        test_acc = metrics.accuracy_score(test_true, test_pred)
        avg_per_class_acc = metrics.balanced_accuracy_score(
            test_true, test_pred)
        outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f' % (
            epoch, test_loss * 1.0 / count, test_acc, avg_per_class_acc)
        io.cprint(outstr)
        # Keep the checkpoint with the best combined (overall + balanced)
        # accuracy; ties favour the later epoch.
        if test_acc + avg_per_class_acc >= best_test_acc + best_avg_per_class_acc:
            best_test_acc = test_acc
            best_avg_per_class_acc = avg_per_class_acc
            print('This is the current best.')
            torch.save(model.state_dict(),
                       'checkpoints/%s/models/model.t7' % args.exp_name)
Exemplo n.º 3
0
def train(modelin=args.model, modelout=args.out,device=args.device,opt=args.opt):
    """Train the calibration and shape networks, looping over epochs forever.

    ``calib_net`` (PointNet with 1 output) predicts a value that, offset by
    300, fills K[0,0] and K[1,1] of the intrinsic matrix and is compared to
    ``f_gt``. ``sfm_net`` (PointNet with 199 outputs) predicts coefficients
    for the 3DMM landmark eigenvectors, which are added to the mean
    landmark shape. Both are trained with an L1 loss on the focal value
    plus the 3D shape. After every pass over the loader both state dicts
    are saved under ``model/`` and ``test`` is run.

    Note that the default arguments are read from the module-level ``args``
    when this function is defined, not when it is called.

    Args:
        modelin: checkpoint suffix to warm-start from
            (``model/calib_<modelin>`` and ``model/sfm_<modelin>``);
            ``""`` starts from freshly initialised networks.
        modelout: suffix used when saving checkpoints.
        device: device the networks and batches are moved to.
        opt: when falsy, the networks are trained. The truthy path is
            incomplete in this code (see the note near the end of the loop).
    """

    def load_partial(net, path):
        """Load the entries of checkpoint ``path`` whose keys exist in ``net``.

        Parameters missing from the checkpoint keep their current values, so
        a partially matching checkpoint can still warm-start the network.
        """
        # Load onto the CPU: the networks are moved to `device` afterwards.
        pretrained = torch.load(path, map_location='cpu')
        net_dict = net.state_dict()
        net_dict.update({k: v for k, v in pretrained.items() if k in net_dict})
        # Load the merged dict (not just the filtered checkpoint) so strict
        # loading does not fail on keys the checkpoint lacks.
        net.load_state_dict(net_dict)

    # define model, dataloader, 3dmm eigenvectors, optimization method
    calib_net = PointNet(n=1)
    sfm_net = PointNet(n=199)
    if modelin != "":
        load_partial(calib_net, os.path.join('model', 'calib_' + modelin))
        load_partial(sfm_net, os.path.join('model', 'sfm_' + modelin))

    calib_net.to(device=device)
    sfm_net.to(device=device)
    opt1 = torch.optim.Adam(calib_net.parameters(), lr=1e-3)
    opt2 = torch.optim.Adam(sfm_net.parameters(), lr=1e-3)

    # dataloader
    data = dataloader.Data()
    loader = data.batchloader
    batch_size = data.batchsize

    # Mean shape and eigenvectors for the 3DMM, replicated once per batch
    # element. clone() keeps the in-place z flip below from writing through
    # torch.from_numpy's shared memory into data.mu_lm.
    # NOTE(review): replicated for data.batchsize — a smaller final batch
    # would not match these shapes; confirm the loader yields full batches.
    mu_lm = torch.from_numpy(data.mu_lm).float().clone()
    mu_lm[:, 2] = mu_lm[:, 2] * -1
    mu_lm = torch.stack(batch_size * [mu_lm.to(device=device)])
    shape = mu_lm
    lm_eigenvec = torch.from_numpy(data.lm_eigenvec).float().to(device=device)
    lm_eigenvec = torch.stack(batch_size * [lm_eigenvec])

    M = data.M
    N = data.N

    # main training loop (never terminates on its own)
    for epoch in itertools.count():
        for batch in loader:

            # get the input and gt values
            x_cam_gt = batch['x_cam_gt'].to(device=device)
            shape_gt = batch['x_w_gt'].to(device=device)
            fgt = batch['f_gt'].to(device=device)
            x_img = batch['x_img'].to(device=device)
            x_img_gt = batch['x_img_gt'].to(device=device).permute(0, 2, 1, 3)
            batch_size = fgt.shape[0]

            # Append a column of ones to the image points.
            one = torch.ones(batch_size, M * N, 1).to(device=device)
            x_img_one = torch.cat([x_img, one], dim=2)
            # NOTE(review): 6800 is hard-coded; presumably equal to M*N — confirm.
            x_cam_pt = x_cam_gt.permute(0, 1, 3, 2).reshape(batch_size, 6800, 3)
            x = x_img.permute(0, 2, 1)

            ptsI = x_img_one.reshape(batch_size, M, N, 3).permute(0, 1, 3, 2)[:, :, :2, :]

            # if just optimizing
            if not opt:
                # calibration: predicted focal value fills the K diagonal
                f = calib_net(x) + 300
                K = torch.zeros((batch_size, 3, 3)).float().to(device=device)
                K[:, 0, 0] = f.squeeze()
                K[:, 1, 1] = f.squeeze()
                K[:, 2, 2] = 1

                # sfm: shape = mean shape + eigenvectors @ predicted coefficients
                betas = sfm_net(x)
                betas = betas.unsqueeze(-1)
                shape = mu_lm + torch.bmm(lm_eigenvec, betas).squeeze().view(batch_size, N, 3)

                opt1.zero_grad()
                opt2.zero_grad()
                f_error = torch.mean(torch.abs(f - fgt))
                error3d = torch.mean(torch.abs(shape - shape_gt))
                error = f_error + error3d
                error.backward()
                opt1.step()
                opt2.step()

                # NOTE(review): indexes samples 0..3, so batches need >= 4 items.
                print(f"f_error: {f_error.item():.3f} | error3d: {error3d.item():.3f} | f/fgt: {f[0].item():.1f}/{fgt[0].item():.1f} | f/fgt: {f[1].item():.1f}/{fgt[1].item():.1f} | f/fgt: {f[2].item():.1f}/{fgt[2].item():.1f} | f/fgt: {f[3].item():.1f}/{fgt[3].item():.1f} ")
                continue

            # get shape error from image projection
            # NOTE(review): on this path f, f_error, rmse, loss1 and loss2 are
            # never assigned in this function, so the print raises
            # (UnboundLocalError/NameError); the projection-based optimisation
            # appears to be missing from this code.
            print(f"f/fgt: {f[0].item():.3f}/{fgt[0].item():.3f} | rmse: {rmse:.3f} | f_rel: {f_error.item():.4f}  | loss1: {loss1.item():.3f} | loss2: {loss2.item():.3f}")

        # save model and increment weight decay
        print("saving!")
        torch.save(sfm_net.state_dict(), os.path.join('model', 'sfm_' + modelout))
        torch.save(calib_net.state_dict(), os.path.join('model', 'calib_' + modelout))
        test(modelin=args.out, outfile=args.out, optimize=False)
Exemplo n.º 4
0
def test(args, io):
    """Evaluate a saved ModelNet40 classifier; report its cost and accuracy.

    Builds the network named by ``args.model`` (dropout 0 for 'ssg'/'msg').
    It then loads the weights in ``args.model_path``, accepting both plain
    state dicts and ones saved from an ``nn.DataParallel`` wrapper. Next it
    prints MACs and parameter count from ``get_model_complexity_info`` on a
    single test sample. Finally it logs overall and class-balanced test
    accuracy through ``io.cprint``.

    Args:
        args: parsed command-line options.
        io: logger exposing ``cprint(str)``.

    Raises:
        Exception: if ``args.model`` names an unsupported architecture.
        RuntimeError: if the checkpoint matches neither key layout.
    """
    test_loader = DataLoader(ModelNet40(partition='test',
                                        num_points=args.num_points),
                             batch_size=args.test_batch_size,
                             shuffle=True,
                             drop_last=False)

    device = torch.device("cuda" if args.cuda else "cpu")

    # Build the requested model.
    if args.model == 'pointnet':
        model = PointNet(args).to(device)
    elif args.model == 'dgcnn':
        model = DGCNN(args).to(device)
    elif args.model == 'ssg':
        model = PointNet2SSG(output_classes=40, dropout_prob=0)
        model.to(device)
    elif args.model == 'msg':
        model = PointNet2MSG(output_classes=40, dropout_prob=0)
        model.to(device)
    elif args.model in ('ognet', 'ognet-small'):
        # Widths noted by the original config: ognet [64,128,256,512],
        # ognet-small [48,96,192,384]; both take them from args.feature_dims.
        ognet_kwargs = dict(output_classes=40,
                            init_points=768,
                            input_dims=3,
                            dropout_prob=args.dropout,
                            id_skip=args.id_skip,
                            drop_connect_rate=args.drop_connect_rate,
                            cluster='xyzrgb',
                            pre_act=args.pre_act,
                            norm=args.norm_layer)
        if args.model == 'ognet' and args.efficient:
            # Build only the efficient variant rather than constructing a
            # Model_dense first and immediately discarding it.
            model = ModelE_dense(20, args.feature_dims, [512],
                                 gem=args.gem,
                                 ASPP=args.ASPP,
                                 **ognet_kwargs)
        else:
            model = Model_dense(20, args.feature_dims, [512], **ognet_kwargs)
        model.to(device)
    else:
        raise Exception("Not implemented")

    # Read the checkpoint once. It is mapped onto the evaluation device so
    # that GPU-saved weights also open on CPU-only machines.
    state = torch.load(args.model_path, map_location=device)
    try:
        model.load_state_dict(state)
    except RuntimeError:
        # State dicts saved from nn.DataParallel prefix every key with
        # "module."; load through a temporary wrapper around the same module.
        nn.DataParallel(model).load_state_dict(state)
    # `model` is the bare module on both paths, so no `.module` unwrap is needed.
    model.eval()

    batch0, label0 = next(iter(test_loader))
    batch0 = batch0[0].unsqueeze(0)
    print(batch0.shape)
    print(model)

    # NOTE(review): batch0 stays on the CPU even when the model is on CUDA;
    # confirm the complexity helper handles device placement.
    macs, params = get_model_complexity_info(model,
                                             batch0, ((1024, 3)),
                                             as_strings=True,
                                             print_per_layer_stat=False,
                                             verbose=True)

    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))

    # These architectures receive the batch tensor twice, in its loaded
    # layout; the others receive it with the last two axes swapped.
    takes_input_twice = args.model in ('ognet', 'ognet-small', 'ssg', 'msg')

    test_true = []
    test_pred = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for data, label in test_loader:
            # reshape(-1) instead of squeeze(): a final batch of one still
            # yields a 1-D label vector that np.concatenate accepts.
            data, label = data.to(device), label.to(device).reshape(-1)
            if takes_input_twice:
                logits = model(data, data)
            else:
                logits = model(data.permute(0, 2, 1))
            preds = logits.max(dim=1)[1]
            test_true.append(label.cpu().numpy())
            test_pred.append(preds.cpu().numpy())
    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    test_acc = metrics.accuracy_score(test_true, test_pred)
    avg_per_class_acc = metrics.balanced_accuracy_score(test_true, test_pred)
    outstr = 'Test :: test acc: %.6f, test avg acc: %.6f' % (test_acc,
                                                             avg_per_class_acc)
    io.cprint(outstr)