Example #1
    def __init__(self, train_loader, val_loader, args, njoints, device, idx):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_num_batches = len(self.train_loader)
        self.val_num_batches = len(self.val_loader)
        self.batch_size = args.train_batch
        self.num_classes = njoints
        self.device = device
        self.idx = idx
        self.epochs = args.epochs

        self.schedule = args.schedule
        self.gamma = args.gamma
        self.sigma_decay = args.sigma_decay
        self.checkpoint = args.checkpoint
        self.snapshot = args.snapshot
        self.num_stacks = args.stacks
        self.debug = args.debug
        self.flip = args.flip

        print("==> creating model '{}', stacks={}, blocks={}".format(
            args.arch, args.stacks, args.blocks))
        (self.netG, self.netsD, self.start_epoch, self.best_acc, self.logger,
         self.optimizerG, self.optimizersD,
         self.domainD) = self.build_models(num_stacks=args.stacks,
                                           num_blocks=args.blocks,
                                           num_classes=njoints,
                                           resnet_layers=args.resnet_layers,
                                           device=device,
                                           args=args)
        self.real_labels, self.fake_labels = self.prepare_labels()

        self.criterion = losses.JointsMSELoss().to(device)
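
Every example constructs the same heatmap criterion, losses.JointsMSELoss(). Its definition is not shown in these snippets; below is a minimal sketch, assuming it is a plain mean-squared error over predicted versus ground-truth joint heatmaps (the repo's version may additionally weight joints by visibility):

import torch
import torch.nn as nn


class JointsMSELoss(nn.Module):
    """Sketch: MSE between predicted and target joint heatmaps.

    Assumes inputs of shape (batch, num_joints, H, W).
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, output, target):
        return self.mse(output, target)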
Example #2
def forward_pass(args):

    njoints = datasets.__dict__[args.dataset].njoints

    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)

    model = torch.nn.DataParallel(model).to(device)

    # define loss function (criterion) and optimizer
    criterion = losses.JointsMSELoss().to(device)

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
        )
    else:
        print('Unknown solver: {}'.format(args.solver))
        assert False

    lr = args.lr
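
forward_pass ends by seeding lr = args.lr, which the training loop in Example #5 threads through adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma). That helper's body is not shown; here is a sketch of the step decay those calls imply, assuming schedule is a list of milestone epochs:

def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
    """Sketch: multiply lr by gamma whenever epoch hits a milestone."""
    if epoch in schedule:
        lr *= gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    return lr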
Example #3
    def __init__(self):

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        cudnn.benchmark = True

        self.img_path = ''
        self.dataset = 'mpii'
        self.image_path = ''
        self.inp_res = 256
        self.out_res = 64

        self.arch = 'hg'
        self.stacks = 2
        self.blocks = 1
        self.features = 256
        self.resnet_layers = 50

        self.solver = 'rms'
        self.workers = 1
        self.epochs = 100
        self.test_batch = 1
        self.train_batch = 1
        self.lr = 2.5e-4
        self.momentum = 0
        self.weight_decay = 0
        self.gamma = 0.1

        self.sigma = 1.0
        self.scale_factor = 0.25
        self.rot_factor = 1
        self.sigma_decay = 0

        #self.checkpoint=''
        self.resume = ''
        self.njoints = 24

        self.model = models.hg(num_stacks=self.stacks,
                               num_blocks=self.blocks,
                               num_classes=self.njoints,
                               resnet_layers=self.resnet_layers)

        self.model = torch.nn.DataParallel(self.model).to(self.device)
        self.criterion = losses.JointsMSELoss().to(self.device)
        self.optimizer = torch.optim.RMSprop(self.model.parameters(),
                                             lr=self.lr,
                                             momentum=self.momentum,
                                             weight_decay=self.weight_decay)

        # TODO: move the hardcoded model path into the argument parser
        self.checkpoint = torch.load(
            "/home/rbccps2080ti/projects/link_speed_estimation/ta_darknet/hourglass/checkpoint/mpii/hg_updated_21/checkpoint.pth.tar"
        )
        self.start_epoch = self.checkpoint['epoch']
        self.model.load_state_dict(self.checkpoint['state_dict'])
        self.model.eval()
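
The __init__ above leaves the restored model in eval mode, so the object is ready for inference. A hypothetical usage sketch (Estimator is an assumed class name; the snippet omits it):

import torch

estimator = Estimator()  # assumed name for the class defined above
dummy = torch.randn(1, 3, estimator.inp_res, estimator.inp_res)
with torch.no_grad():
    heatmaps = estimator.model(dummy.to(estimator.device))
# Stacked-hourglass models here return one heatmap tensor per stack.
print(heatmaps[-1].shape)  # expected: torch.Size([1, 24, 64, 64])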
Example #4
def main(args):
    global best_acc
    global idx1
    global idx2

    # idx1 and idx2 are the joint indices used to compute accuracy for each dataset

    idx1 = range(1, 19)
    idx2 = range(1, 19)  # horse

    # create model
    njoints = datasets.__dict__[args.dataset].njoints
    print("==> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch](num_classes=njoints,
                                       resnet_layers=args.resnet_layers,
                                       pretrained=None,
                                       dual_branch=True)

    model = torch.nn.DataParallel(model).to(device)

    # define loss function (criterion) and optimizer
    criterion = losses.JointsMSELoss().to(device)

    # optionally resume from a checkpoint
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict_ema'])

            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        raise Exception('please provide a checkpoint')

    val_dataset = datasets.__dict__[args.dataset](is_train=False,
                                                  is_aug=False,
                                                  **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    _, acc, predictions = validate(val_loader, model, criterion, njoints, args,
                                   args.flip, args.test_batch)
    return
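
Unlike the other examples, this one restores checkpoint['state_dict_ema'] rather than checkpoint['state_dict']. The EMA bookkeeping itself is not shown; the common recipe it suggests is a shadow copy of the weights updated after each optimizer step, sketched below (an assumption about this repo, not its verified code):

import torch

ema_decay = 0.999
ema_state = {k: v.clone() for k, v in model.state_dict().items()}

def update_ema(model, ema_state, decay=ema_decay):
    with torch.no_grad():
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                ema_state[k].mul_(decay).add_(v, alpha=1.0 - decay)
            else:
                ema_state[k].copy_(v)  # non-float buffers: track directly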
Example #5
def main(args):
    global best_acc
    global idx

    # idx lists the joint indices used to compute accuracy
    if args.dataset in ['mpii', 'lsp']:
        idx = [1, 2, 3, 4, 5, 6, 11, 12, 15, 16]
    elif args.dataset == 'coco':
        idx = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
    else:
        print("Unknown dataset: {}".format(args.dataset))
        assert False

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    njoints = datasets.__dict__[args.dataset].njoints

    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)

    model = torch.nn.DataParallel(model).to(device)

    # define loss function (criterion) and optimizer
    criterion = losses.JointsMSELoss().to(device)

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
        )
    else:
        print('Unknown solver: {}'.format(args.solver))
        assert False

    # optionally resume from a checkpoint
    title = args.dataset + ' ' + args.arch
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'),
                            title=title,
                            resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # create data loader
    train_dataset = datasets.__dict__[args.dataset](is_train=True,
                                                    **vars(args))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_dataset = datasets.__dict__[args.dataset](is_train=False, **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # evaluation only
    if args.evaluate:
        print('\nEvaluation only')
        loss, acc, predictions = validate(val_loader, model, criterion,
                                          njoints, args.debug, args.flip)
        save_pred(predictions, checkpoint=args.checkpoint)
        return

    # train and eval
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # decay sigma
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        # train for one epoch
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, args.debug, args.flip)

        # evaluate on validation set
        valid_loss, valid_acc, predictions = validate(val_loader, model,
                                                      criterion, njoints,
                                                      args.debug, args.flip)

        # append logger file
        logger.append(
            [epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])

        # remember best acc and save checkpoint
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            predictions,
            is_best,
            checkpoint=args.checkpoint,
            snapshot=args.snapshot)

    logger.close()
    logger.plot(['Train Acc', 'Val Acc'])
    savefig(os.path.join(args.checkpoint, 'log.eps'))
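
The loop above delegates persistence to save_checkpoint, whose definition is not included. A minimal sketch, assuming is_best triggers a model_best copy and snapshot controls periodic archives (the real helper presumably also writes predictions to disk; that part is omitted):

import os
import shutil

import torch

def save_checkpoint(state, predictions, is_best, checkpoint='checkpoint',
                    snapshot=None, filename='checkpoint.pth.tar'):
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if snapshot and state['epoch'] % snapshot == 0:
        shutil.copyfile(filepath, os.path.join(
            checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))
    if is_best:
        shutil.copyfile(filepath,
                        os.path.join(checkpoint, 'model_best.pth.tar'))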
Example #6
def main(args):
    global best_acc
    global idx1
    global idx2

    # idx1 and idx2 are the joint indices used to compute accuracy for each dataset
    if args.dataset1 in ['real_animal', 'real_animal_sp',
                         'synthetic_animal', 'synthetic_animal_sp']:
        idx1 = range(1, 19)
    else:
        print("Unknown dataset: {}".format(args.dataset1))
        assert False

    if args.dataset2 in ['real_animal', 'real_animal_sp']:
        if args.animal == 'horse':
            idx2 = range(1, 19)  # horse
        elif args.animal == 'tiger':
            idx2 = [
                1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 17, 18, 13, 14, 9, 10, 11, 12
            ]  # tiger
        else:
            print("Unknown animal: {}".format(args.animal))
            assert False
    else:
        print("Unknown dataset: {}".format(args.dataset2))
        assert False

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    njoints = datasets.__dict__[args.dataset1].njoints
    print(njoints)
    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)

    model = torch.nn.DataParallel(model).to(device)

    # define loss function (criterion) and optimizer
    criterion = losses.JointsMSELoss().to(device)

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
        )
    else:
        print('Unknown solver: {}'.format(args.solver))
        assert False

    # optionally resume from a checkpoint
    title = args.dataset1 + ' ' + args.arch
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            #model.load_state_dict(checkpoint)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        raise Exception('please provide a checkpoint')

    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # create data loader
    train_dataset = datasets.__dict__[args.dataset1](is_train=True,
                                                     **vars(args))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_dataset = datasets.__dict__[args.dataset2](is_train=False,
                                                   **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # evaluation only
    if args.evaluate:
        print('\nEvaluation only')
        _, acc, predictions = validate(val_loader, model, criterion, njoints,
                                       args, args.flip, args.test_batch)
        #save_pred(predictions, checkpoint=args.checkpoint)
        return
Example #7
def main(args):
    global best_acc
    global idx
    global customMpiiObject

    # idx lists the joint indices used to compute accuracy
    if args.dataset in ['mpii', 'lsp']:
        idx = [1, 2, 3, 4, 5, 6, 11, 12, 15, 16]
    elif args.dataset == 'coco':
        idx = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
    else:
        print("Unknown dataset: {}".format(args.dataset))
        assert False

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    njoints = CustomMpii.njoints

    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)

    model = torch.nn.DataParallel(model).to(device)

    # define loss function (criterion) and optimizer
    criterion = losses.JointsMSELoss().to(device)

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
        )
    else:
        print('Unknown solver: {}'.format(args.solver))
        assert False

    # optionally resume from a checkpoint
    title = args.dataset + ' ' + args.arch
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'),
                            title=title,
                            resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    print("Loading validation dataset...")
    customMpiiObject = CustomMpii(is_train=False, **vars(args))

    loss, acc, predictions = validate(args.unit_path, model, criterion,
                                      njoints, args.debug, args.flip)
    print(predictions)

    if args.debug:
        paint.main(args.unit_path, predictions.squeeze().tolist())

    return
Example #8
import os

import torch

from pose.utils.imutils import batch_with_heatmap
from pose.utils.transforms import fliplr, flip_back
import pose.models as models
import pose.losses as losses

import scripts.ssl_datasets as datasets
from options.train_options import TrainOptions
from CCSSL.scripts.timer import Timer
from CCSSL.scripts.consistency import prediction_check

opt = TrainOptions()
args = opt.initialize()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
idx = range(1, args.num_classes + 1)
njoints = len(idx)
criterion = losses.JointsMSELoss().to(device)
global_animal = args.animal


def main():

    _t = {'iter time': Timer()}

    # create directory
    model_name = args.source + '_to_' + args.target
    if not os.path.exists(args.checkpoint):
        os.makedirs(args.checkpoint)
        os.makedirs(os.path.join(args.checkpoint, 'logs'))
        os.makedirs(os.path.join(args.checkpoint, 'ssl_labels'))
    opt.print_options(args)
Example #9
def main(args):

    img_path = "/home/shantam/Documents/Programs/hourglasstensorlfow/images/cropped0.jpg"
    img = load_image(img_path)
    c = [img.shape[1] / 2, img.shape[2] / 2]
    s = float(img.shape[1] / 200.0)

    img = crop(img_path, img, c, s, [256, 256])
    trans = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    #img = trans(img)
    img = img.unsqueeze(0)
    print(img.shape)

    njoints = 24

    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)

    model = torch.nn.DataParallel(model).to(device)
    criterion = losses.JointsMSELoss().to(device)

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
        )
    else:
        print('Unknown solver: {}'.format(args.solver))
        assert False

    checkpoint = torch.load(
        "/home/shantam/Documents/Programs/pytorch-pose/checkpoint/mpii/hg_fullset/checkpoint.pth.tar"
    )
    args.start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    img = img.to(device, non_blocking=True)
    output = model(img)

    #print (len(output))

    score_map = output[-1].cpu() if isinstance(output, list) else output.cpu()
    preds, vals = final_preds(score_map, [c], [s], [64, 64])

    image = cv2.imread(img_path)
    coords = np.squeeze(preds)

    for m in range(len(coords)):
        val = vals[0][m].detach().numpy()
        print(val)
        if val > 0.25:  # threshold for confidence score
            # cv2.circle requires integer pixel coordinates
            x, y = int(coords[m][0]), int(coords[m][1])
            print(x, y)
            cv2.circle(image, (x, y), 1, (0, 0, 255), -1)

    while True:
        cv2.imshow("dec", image)
        if cv2.waitKey(10) & 0xFF == ord('q'):  # press 'q' to quit
            break
    cv2.destroyAllWindows()
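
final_preds above converts heatmaps into image coordinates. Its first stage, decoding the per-joint peaks, can be sketched as below; the real helper also warps the peaks back through the crop's center c and scale s, which this sketch omits:

import torch

def decode_heatmap_peaks(score_map):
    """Sketch: per-joint (x, y) argmax locations and their confidences.

    score_map: (batch, num_joints, H, W) tensor of heatmaps.
    """
    batch, num_joints, h, w = score_map.shape
    flat = score_map.view(batch, num_joints, -1)
    vals, idxs = flat.max(dim=2)
    xs = (idxs % w).float()
    ys = torch.div(idxs, w, rounding_mode='floor').float()
    return torch.stack([xs, ys], dim=2), vals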