Example #1
def main(args):
    print_args(args)

    plfd_backbone = PFLDInference().cuda()
    if args.resume:
        try:
            plfd_backbone.load_state_dict(
                torch.load(args.resume,
                           map_location=lambda storage, loc: storage))
            logging.info("loaded %s successfully" % args.resume)
        except KeyError:
            plfd_backbone = torch.nn.DataParallel(plfd_backbone)
            plfd_backbone.load_state_dict(torch.load(args.resume))

    step_epoch = [int(x) for x in args.step.split(',')]
    if args.loss == 'mse':
        criterion = MSELoss()
    elif args.loss == 'smoothl1':
        criterion = SmoothL1()
    elif args.loss == 'wing':
        criterion = WingLoss()
    elif args.loss == 'pfld':
        criterion = PFLDLoss()
    else:
        raise ValueError('unsupported loss type: %s' % args.loss)
    cur_lr = args.base_lr
    optimizer = torch.optim.Adam(plfd_backbone.parameters(),
                                 lr=args.base_lr,
                                 weight_decay=args.weight_decay)

    # The SGD optimizer is very sensitive to the magnitude of the loss and tends
    # to diverge or produce NaNs here, so Adam is recommended instead.
    # optimizer = torch.optim.SGD(plfd_backbone.parameters(), lr=args.base_lr, momentum=0.9, weight_decay=args.weight_decay)

    train_transform = transforms.Compose([
        transforms.RandomGrayscale(p=0.2),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
        transforms.ToTensor()
    ])
    wlfwdataset = WLFWDatasets(args.dataroot, train_transform)
    dataloader = DataLoader(wlfwdataset,
                            batch_size=args.train_batchsize,
                            shuffle=True,
                            num_workers=args.workers,
                            drop_last=False)
    val_transform = transforms.Compose([transforms.ToTensor()])
    wlfw_val_dataset = WLFWDatasets(args.val_dataroot, val_transform)
    wlfw_val_dataloader = DataLoader(wlfw_val_dataset,
                                     batch_size=args.val_batchsize,
                                     shuffle=False,
                                     num_workers=args.workers)

    step_index = 0
    writer = SummaryWriter(args.tensorboard)
    for epoch in range(args.start_epoch, args.end_epoch + 1):
        train_pose_loss, train_lds_loss = train(dataloader, plfd_backbone,
                                                criterion, optimizer, epoch)
        filename = os.path.join(str(args.snapshot),
                                "checkpoint_epoch_" + str(epoch) + '.pth')
        save_checkpoint(plfd_backbone.state_dict(), filename)
        val_pose_loss, val_lds_loss = validate(wlfw_val_dataloader,
                                               plfd_backbone, criterion, epoch)
        if epoch in step_epoch:
            step_index += 1
            cur_lr = adjust_learning_rate(optimizer, args.base_lr, step_index)

        print(
            'Epoch: %d, train pose loss: %6.4f, train lds loss:%6.4f, val pose MAE:%6.4f, val lds MAE:%6.4f, lr:%8.6f'
            % (epoch, train_pose_loss, train_lds_loss, val_pose_loss,
               val_lds_loss, cur_lr))
        writer.add_scalar('data/pose_loss', train_pose_loss, epoch)
        writer.add_scalars(
            'data/loss', {
                'val pose loss': val_pose_loss,
                'val lds loss': val_lds_loss,
                'train loss': train_lds_loss
            }, epoch)

    writer.close()
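
Example #1 decays the learning rate whenever the epoch reaches one of the milestones passed via --step, by calling an adjust_learning_rate helper that is not shown above. Below is a minimal sketch of such a helper, assuming a simple gamma-style step decay; the gamma default and the exact signature are assumptions, not taken from the original code.

def adjust_learning_rate(optimizer, base_lr, step_index, gamma=0.1):
    # Assumed step decay: scale the base LR by gamma once per milestone reached.
    lr = base_lr * (gamma ** step_index)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
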
Example #2
def main(args):
    # Step 1: parse args config
    logging.basicConfig(
        format=
        '[%(asctime)s] [p%(process)s] [%(pathname)s:%(lineno)d] [%(levelname)s] %(message)s',
        level=logging.INFO,
        handlers=[
            logging.FileHandler(args.log_file, mode='w'),
            logging.StreamHandler()
        ])
    print_args(args)

    # Step 2: model, criterion, optimizer, scheduler
    plfd_backbone = PFLDInference().cuda()
    auxiliarynet = AuxiliaryNet().cuda()
    criterion = PFLDLoss()
    optimizer = torch.optim.Adam([{
        'params': plfd_backbone.parameters()
    }, {
        'params': auxiliarynet.parameters()
    }],
                                 lr=args.base_lr,
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=args.lr_patience, verbose=True)

    # step 3: data
    # augmentation
    transform = transforms.Compose([transforms.ToTensor()])
    wlfwdataset = WLFWDatasets(args.dataroot, transform)
    dataloader = DataLoader(wlfwdataset,
                            batch_size=args.train_batchsize,
                            shuffle=True,
                            num_workers=args.workers,
                            drop_last=False)

    wlfw_val_dataset = WLFWDatasets(args.val_dataroot, transform)
    wlfw_val_dataloader = DataLoader(wlfw_val_dataset,
                                     batch_size=args.val_batchsize,
                                     shuffle=False,
                                     num_workers=args.workers)

    # step 4: run
    writer = SummaryWriter(args.tensorboard)
    for epoch in range(args.start_epoch, args.end_epoch + 1):
        weighted_train_loss, train_loss = train(dataloader, plfd_backbone,
                                                auxiliarynet, criterion,
                                                optimizer, epoch)
        filename = os.path.join(str(args.snapshot),
                                "checkpoint_epoch_" + str(epoch) + '.pth.tar')
        save_checkpoint(
            {
                'epoch': epoch,
                'plfd_backbone': plfd_backbone.state_dict(),
                'auxiliarynet': auxiliarynet.state_dict()
            }, filename)

        val_loss = validate(wlfw_val_dataloader, plfd_backbone, auxiliarynet,
                            criterion, epoch)

        scheduler.step(val_loss)
        writer.add_scalar('data/weighted_loss', weighted_train_loss, epoch)
        writer.add_scalars('data/loss', {
            'val loss': val_loss,
            'train loss': train_loss
        }, epoch)
    writer.close()
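
Example #2 saves both networks into a single dict per epoch through save_checkpoint. The following is a minimal sketch of how such a checkpoint could be restored later, assuming save_checkpoint simply wraps torch.save; the helper itself is not shown above, and the file name used here is only illustrative.

# Assumes the checkpoint was written as
# torch.save({'epoch': ..., 'plfd_backbone': ..., 'auxiliarynet': ...}, filename).
checkpoint = torch.load('checkpoint_epoch_10.pth.tar',
                        map_location=lambda storage, loc: storage)
plfd_backbone = PFLDInference()
auxiliarynet = AuxiliaryNet()
plfd_backbone.load_state_dict(checkpoint['plfd_backbone'])
auxiliarynet.load_state_dict(checkpoint['auxiliarynet'])
start_epoch = checkpoint['epoch'] + 1  # resume training from the next epoch
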
Example #3
def main(args):
    # Step 1: parse args config
    logging.basicConfig(
        format=
        '[%(asctime)s] [p%(process)s] [%(pathname)s:%(lineno)d] [%(levelname)s] %(message)s',
        level=logging.INFO,
        handlers=[
            logging.FileHandler(args.log_file, mode='w'),
            logging.StreamHandler()
        ])
    print_args(args)

    # Step 2: model, criterion, optimizer, scheduler
    if wandb.config.pfld_backbone == "GhostNet":
        plfd_backbone = CustomizedGhostNet(width=wandb.config.ghostnet_width, dropout=0.2)
        logger.info(f"Using GHOSTNET with width={wandb.config.ghostnet_width} as backbone of PFLD backbone")

        # If using pretrained weights from a GhostNet model trained on ImageNet
        if wandb.config.ghostnet_with_pretrained_weight_image_net:
            logger.info("Using pretrained weights of a GhostNet model trained on ImageNet")
            plfd_backbone = load_pretrained_weight_imagenet_for_ghostnet_backbone(
                plfd_backbone, "./checkpoint_imagenet/state_dict_93.98.pth")
            


    else:
        plfd_backbone = PFLDInference().to(device)  # MobileNetV2 by default
        logger.info("Using MobileNetV2 as the PFLD backbone")

    auxiliarynet = AuxiliaryNet().to(device)

    # Watch model by wandb
    wandb.watch(plfd_backbone)
    wandb.watch(auxiliarynet)

    criterion = PFLDLoss()
    optimizer = torch.optim.Adam(
        [{
            'params': plfd_backbone.parameters()
        }, {
            'params': auxiliarynet.parameters()
        }],
        lr=args.base_lr,
        weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=args.lr_patience, verbose=True)

    # step 3: data
    # augmentation
    transform = transforms.Compose([transforms.ToTensor()])
    wlfwdataset = WLFWDatasets(args.dataroot, transform)
    dataloader = DataLoader(
        wlfwdataset,
        batch_size=args.train_batchsize,
        shuffle=True,
        num_workers=args.workers,
        drop_last=False)

    wlfw_val_dataset = WLFWDatasets(args.val_dataroot, transform)
    wlfw_val_dataloader = DataLoader(
        wlfw_val_dataset,
        batch_size=args.val_batchsize,
        shuffle=False,
        num_workers=args.workers)

    # step 4: run
    writer = SummaryWriter(args.tensorboard)
    for epoch in range(args.start_epoch, args.end_epoch + 1):
        weighted_train_loss, train_loss = train(dataloader, plfd_backbone,
                                                auxiliarynet, criterion,
                                                optimizer, epoch)
        filename = os.path.join(
            str(args.snapshot), "checkpoint_epoch_" + str(epoch) + '.pth.tar')
        save_checkpoint({
            'epoch': epoch,
            'plfd_backbone': plfd_backbone.state_dict(),
            'auxiliarynet': auxiliarynet.state_dict()
        }, filename)

        val_loss = validate(wlfw_val_dataloader, plfd_backbone, auxiliarynet,
                            criterion)
        
        wandb.log({"metric/val_loss": val_loss})

        scheduler.step(val_loss)
        writer.add_scalar('data/weighted_loss', weighted_train_loss, epoch)
        writer.add_scalars('data/loss', {'val loss': val_loss, 'train loss': train_loss}, epoch)
    writer.close()
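
Example #3 optionally initializes the GhostNet backbone from ImageNet weights via load_pretrained_weight_imagenet_for_ghostnet_backbone, whose implementation is not shown above. Below is a minimal sketch of such a partial-loading helper, assuming it copies only the pretrained parameters whose names and shapes match the customized backbone; this is an assumption, not the original implementation.

def load_pretrained_weight_imagenet_for_ghostnet_backbone(model, weight_path):
    # Assumed behaviour: copy only the ImageNet weights whose names and shapes
    # match the customized backbone, leaving all other layers untouched.
    pretrained = torch.load(weight_path, map_location=lambda storage, loc: storage)
    model_dict = model.state_dict()
    matched = {k: v for k, v in pretrained.items()
               if k in model_dict and v.shape == model_dict[k].shape}
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return model
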