コード例 #1
0
        rmax = img_width
        rmin -= delt
    if cmax > img_length:
        delt = cmax - img_length
        cmax = img_length
        cmin -= delt
    return rmin, rmax, cmin, cmax


####################################################################################################
################################### load BiSeNet parameters ########################################
####################################################################################################
# Build the BiSeNet network from the command-line options, move it to the GPU
# and restore its weights from opt.checkpoint_path, reporting the load time.
# `bise_model` is assigned at module level, so it is already a global name; a
# `global bise_model` statement here is redundant, and placing one after the
# assignment is rejected by the compiler ("assigned to before global
# declaration").
print('load BiseNet')
start_time = time.time()
bise_model = BiSeNet(opt.num_classes, opt.context_path)
bise_model = bise_model.cuda()
bise_model.load_state_dict(torch.load(opt.checkpoint_path))
# NOTE(review): unlike `estimator` below, this model is not switched to
# eval() here; confirm that happens before it is used for inference.
print('Done!')
print("Load time : {}".format(time.time() - start_time))

#####################################################################################################
######################## load Densefusion Network, 3d model #############################
#####################################################################################################
# Build the DenseFusion network (PoseNet), move it to the GPU and restore its
# weights from opt.model.
print('load densefusion network')
start_time = time.time()
estimator = PoseNet(num_points=num_points, num_obj=num_obj)
estimator.cuda()
estimator.load_state_dict(torch.load(opt.model))
# Inference mode: disables dropout and uses running batch-norm statistics.
estimator.eval()
############################################################################
コード例 #2
0
def _str2bool(value):
    """Convert a command-line string such as "true"/"false" to a bool.

    argparse with ``type=bool`` calls ``bool(string)``, which is True for every
    non-empty string (so ``--use_gpu False`` would enable the GPU); this
    converter interprets the text instead.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'boolean value expected, got %r' % (value,))


def _parse_args(params):
    """Build the training argument parser and parse *params* with it."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs',
                        type=int,
                        default=300,
                        help='Number of epochs to train for')
    parser.add_argument('--epoch_start_i',
                        type=int,
                        default=0,
                        help='Start counting epochs from this number')
    parser.add_argument('--checkpoint_step',
                        type=int,
                        default=10,
                        help='How often to save checkpoints (epochs)')
    parser.add_argument('--validation_step',
                        type=int,
                        default=10,
                        help='How often to perform validation (epochs)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=1,
                        help='Number of images in each batch')
    parser.add_argument(
        '--context_path',
        type=str,
        default="resnet101",
        help='The context path model you are using, resnet18, resnet101.')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.01,
                        help='learning rate used for train')
    parser.add_argument('--data',
                        type=str,
                        default='data',
                        help='path of training data')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='num of workers')
    parser.add_argument('--num_classes',
                        type=int,
                        default=32,
                        help='num of object classes (with void)')
    parser.add_argument('--cuda',
                        type=str,
                        default='0',
                        help='GPU ids used for training')
    # _str2bool instead of type=bool: bool("False") is True.
    parser.add_argument('--use_gpu',
                        type=_str2bool,
                        default=True,
                        help='whether to user gpu for training')
    parser.add_argument('--pretrained_model_path',
                        type=str,
                        default=None,
                        help='path to pretrained model')
    parser.add_argument('--save_model_path',
                        type=str,
                        default="checkpoints",
                        help='path to save model')
    parser.add_argument('--optimizer',
                        type=str,
                        default='rmsprop',
                        help='optimizer, support rmsprop, sgd, adam')
    parser.add_argument('--loss',
                        type=str,
                        default='crossentropy',
                        help='loss function, dice or crossentropy')
    return parser.parse_args(params)


def main(params):
    """Train a BiSeNet model on the VOC data at ``--data``, then validate it.

    Args:
        params: list of command-line style argument strings (for example
            ``sys.argv[1:]``); see ``_parse_args`` for the accepted options.

    Returns:
        None. Returns early, without training, when ``--optimizer`` is not one
        of rmsprop, sgd or adam.
    """
    args = _parse_args(params)

    # Restrict the visible GPUs before anything can initialise CUDA; changing
    # this variable after the CUDA context exists has no effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda

    # create datasets and dataloaders
    train_path = args.data
    train_transform, val_transform = get_transform()

    # VOC dataset object for training
    dataset_train = VOC(train_path,
                        image_set="train",
                        transform=train_transform)
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  drop_last=True)

    # VOC dataset object for validation
    dataset_val = VOC(train_path, image_set="val", transform=val_transform)
    dataloader_val = DataLoader(dataset_val,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=args.num_workers)

    # build model
    model = BiSeNet(args.num_classes, args.context_path)
    if torch.cuda.is_available() and args.use_gpu:
        model = model.cuda()

    # build optimizer
    if args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(), args.learning_rate)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.learning_rate,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
    else:  # unknown optimizer name: report it and abort without training
        print('not supported optimizer \n')
        return None

    # load pretrained weights if a path was given (none by default)
    if args.pretrained_model_path is not None:
        print('load model from %s ...' % args.pretrained_model_path)
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only runs;
        # load_state_dict then copies the tensors onto the model's own device.
        state_dict = torch.load(args.pretrained_model_path, map_location='cpu')
        model.load_state_dict(state_dict)
        print('Done!')

    # train() and val() are defined elsewhere in this file
    train(args, model, optimizer, dataloader_train, dataloader_val)

    val(args, model, dataloader_val)