Example #1
def train(args):
  
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = True
  
  
  print('Arguments : -------------------------------')
  for name, value in args._get_kwargs():
    print('{:16} : {:}'.format(name, value))
    
    
  # Data Augmentation    
  mean_fill   = tuple( [int(x*255) for x in [0.485, 0.456, 0.406] ] )
  normalize   = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])


  # train_transform = [transforms.AugTransBbox(1, 0.5)]
  train_transform = [transforms.PreCrop(args.pre_crop_expand)]
  
  train_transform += [transforms.TrainScale2WH((1024, 1024))]
  train_transform += [transforms.AugHorizontalFlip(args.flip_prob)]

  train_transform += [transforms.ToTensor()]
  train_transform  = transforms.Compose( train_transform )


  # Training datasets
  train_data = GeneralDataset(args.num_pts, train_transform, args.train_lists)
  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

  
    
  net = Model(args.num_pts)
  
  # print(len(net.children()))
  #for m in net.children():
  #  print(type(m))
  criterion = wing_loss(args) 

  optimizer = torch.optim.SGD(net.parameters(), lr=args.LR, momentum=args.momentum,
                              weight_decay=args.decay, nesterov=args.nesterov)
    
  scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.schedule, gamma=args.gamma)
  
  net = net.cuda()
  net = torch.nn.DataParallel(net)
    
    
    
  print('--------------', len(train_loader))
  for epoch in range(3):
    for i, (inputs, target, mask, cropped_size) in enumerate(train_loader):
        
        
      target = target.squeeze(1)
      inputs = inputs.cuda()
      target = target.cuda()
      mask = mask.cuda()
      prediction = net(inputs)            
      loss = criterion(prediction, target, mask) 
      
      nums_img = inputs.size()[0]
      for j in range(nums_img): 
        temp_img = inputs[j].permute(1, 2, 0)
        temp_img = temp_img.mul(255).byte().cpu().numpy()
        temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR)

        pts = []
        for d in range(args.num_pts):
          pts.append((target[j][2*d].item(), target[j][2*d+1].item()))
        draw_points(temp_img, pts, (0, 255, 255))
        # `meta` is not available inside this loop, so the bbox overlay is disabled:
        # bbox = [int(index[0].item()) for index in meta]
        # cv2.rectangle(temp_img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 4)
        cv2.imwrite('{}-{}-{}.jpg'.format(epoch, i, j), temp_img)
        # assert 1==0
      #if i > 5:
      #  break
  for a, v in enumerate(train_data.data_value):
   
    image = cv2.imread(v['image_path'])
    meta = v['meta']
    # bbox = v['bbox']
    bbox = v['meta'].get_box()
    pts = []
    for d in range(args.num_pts):
      pts.append((meta.points[0, d], meta.points[1, d]))
    draw_points(image, pts, (0, 255, 255))
    cv2.rectangle(image,(int(bbox[0]), int(bbox[1])),(int(bbox[2]), int(bbox[3])),(0,255,0),4)
    cv2.imwrite('ori_{}.jpg'.format(a), image)
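
This example (and Example #3 below) builds its criterion with wing_loss(args), whose definition is not shown. A minimal sketch of the wing loss from Feng et al. (CVPR 2018); the w/epsilon hyper-parameters and the optional visibility mask are assumptions, since the original class is not included:

import math
import torch

class WingLoss(torch.nn.Module):
    """Wing loss: logarithmic near zero, shifted L1 far from zero."""
    def __init__(self, w=10.0, epsilon=2.0):
        super().__init__()
        self.w, self.epsilon = w, epsilon
        # Offset that makes the two pieces meet at |x| = w.
        self.C = w - w * math.log(1.0 + w / epsilon)

    def forward(self, prediction, target, mask=None):
        diff = (prediction - target).abs()
        if mask is not None:
            diff = diff * mask  # zero out occluded/invisible landmarks (assumed semantics)
        loss = torch.where(diff < self.w,
                           self.w * torch.log(1.0 + diff / self.epsilon),
                           diff - self.C)
        return loss.mean()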
Example #2
def main(args):
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = True
  prepare_seed(args.rand_seed)

  logstr = 'seed-{:}-time-{:}'.format(args.rand_seed, time_for_file())
  logger = Logger(args.save_path, logstr)
  logger.log('Main Function with logger : {:}'.format(logger))
  logger.log('Arguments : -------------------------------')
  for name, value in args._get_kwargs():
    logger.log('{:16} : {:}'.format(name, value))
  logger.log("Python  version : {}".format(sys.version.replace('\n', ' ')))
  logger.log("Pillow  version : {}".format(PIL.__version__))
  logger.log("PyTorch version : {}".format(torch.__version__))
  logger.log("cuDNN   version : {}".format(torch.backends.cudnn.version()))

  # General Data Augmentation
  mean_fill   = tuple( [int(x*255) for x in [0.485, 0.456, 0.406] ] )
  normalize   = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])
  print('The flip is : {}, rotate is {}'.format(args.arg_flip, args.rotate_max))
  train_transform  = [transforms.PreCrop(args.pre_crop_expand)]
  train_transform += [transforms.TrainScale2WH((args.crop_width, args.crop_height))]
  train_transform += [transforms.AugScale(args.scale_prob, args.scale_min, args.scale_max)]
  if args.arg_flip:
    train_transform += [transforms.AugHorizontalFlip(args.flip_prob)]
  if args.rotate_max:
    train_transform += [transforms.AugRotate(args.rotate_max)]
  train_transform += [transforms.AugCrop(args.crop_width, args.crop_height, args.crop_perturb_max, mean_fill)]
  train_transform += [transforms.ToTensor(), normalize]
  train_transform  = transforms.Compose( train_transform )

  eval_transform  = transforms.Compose([transforms.PreCrop(args.pre_crop_expand), transforms.TrainScale2WH((args.crop_width, args.crop_height)),  transforms.ToTensor(), normalize])
  assert (args.scale_min+args.scale_max) / 2 == args.scale_eval, 'The scale is not ok : {},{} vs {}'.format(args.scale_min, args.scale_max, args.scale_eval)
  
  # Model Configure Load
  model_config = load_configure(args.model_config, logger)
  args.sigma   = args.sigma * args.scale_eval
  logger.log('Real Sigma : {:}'.format(args.sigma))

  # Training Dataset
  if args.regression:
    train_data   = GeneralDatasetForRegression(train_transform, args.data_indicator)
  else:
    train_data   = GeneralDataset(train_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
  train_data.load_list(args.train_lists, args.num_pts, True)
  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)


  # Evaluation Dataloader
  eval_loaders = []
  if args.eval_vlists is not None:
    for eval_vlist in args.eval_vlists:
      if args.regression:
        eval_vdata = GeneralDatasetForRegression(eval_transform, args.data_indicator)
      else:
        eval_vdata = GeneralDataset(eval_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
      eval_vdata.load_list(eval_vlist, args.num_pts, True)
      eval_vloader = torch.utils.data.DataLoader(eval_vdata, batch_size=args.batch_size, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True)
      eval_loaders.append((eval_vloader, True))

  if args.eval_ilists is not None:
    for eval_ilist in args.eval_ilists:
      if args.regression:
        eval_idata = GeneralDatasetForRegression(eval_transform, args.data_indicator)
      else:
        eval_idata = GeneralDataset(eval_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
      eval_idata.load_list(eval_ilist, args.num_pts, True)
      eval_iloader = torch.utils.data.DataLoader(eval_idata, batch_size=args.batch_size, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True)
      eval_loaders.append((eval_iloader, False))

  # Define network
  logger.log('configure : {:}'.format(model_config))
  if args.regression:
    net = obtain_model(model_config, args.num_pts)
  else:
    net = obtain_model(model_config, args.num_pts + 1)
  assert model_config.downsample == net.downsample, 'downsample is not correct : {} vs {}'.format(model_config.downsample, net.downsample)
  logger.log("=> network :\n {}".format(net))

  logger.log('Training-data : {:}'.format(train_data))
  for i, eval_loader in enumerate(eval_loaders):
    eval_loader, is_video = eval_loader
    logger.log('The [{:2d}/{:2d}]-th testing-data [{:}] = {:}'.format(i, len(eval_loaders), 'video' if is_video else 'image', eval_loader.dataset))
    
  logger.log('arguments : {:}'.format(args))

  opt_config = load_configure(args.opt_config, logger)

  if hasattr(net, 'specify_parameter'):
    net_param_dict = net.specify_parameter(opt_config.LR, opt_config.Decay)
  else:
    net_param_dict = net.parameters()

  optimizer, scheduler, criterion = obtain_optimizer(net_param_dict, opt_config, logger)
  logger.log('criterion : {:}'.format(criterion))
  net, criterion = net.cuda(), criterion.cuda()
  net = torch.nn.DataParallel(net)

  last_info = logger.last_info()
  if last_info.exists():
    logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
    last_info = torch.load(last_info)
    start_epoch = last_info['epoch'] + 1
    checkpoint  = torch.load(last_info['last_checkpoint'])
    assert last_info['epoch'] == checkpoint['epoch'], 'Last-Info is not right {:} vs {:}'.format(last_info, checkpoint['epoch'])
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    logger.log("=> load-ok checkpoint '{:}' (epoch {:}) done" .format(logger.last_info(), checkpoint['epoch']))
  else:
    logger.log("=> do not find the last-info file : {:}".format(last_info))
    start_epoch = 0


  if args.eval_once:
    logger.log("=> only evaluate the model once")
    eval_results = eval_all(args, eval_loaders, net, criterion, 'eval-once', logger, opt_config)
    logger.close() ; return


  # Main Training and Evaluation Loop
  start_time = time.time()
  epoch_time = AverageMeter()
  for epoch in range(start_epoch, opt_config.epochs):

    scheduler.step()
    need_time = convert_secs2time(epoch_time.avg * (opt_config.epochs-epoch), True)
    epoch_str = 'epoch-{:03d}-{:03d}'.format(epoch, opt_config.epochs)
    LRs       = scheduler.get_lr()
    logger.log('\n==>>{:s} [{:s}], [{:s}], LR : [{:.5f} ~ {:.5f}], Config : {:}'.format(time_string(), epoch_str, need_time, min(LRs), max(LRs), opt_config))

    # train for one epoch
    if args.regression:
      train_loss, train_nme = basic_train_regression(args, train_loader, net, criterion, optimizer, epoch_str, logger, opt_config)
    else:
      train_loss, train_nme = basic_train(args, train_loader, net, criterion, optimizer, epoch_str, logger, opt_config)

    # log the results    
    logger.log('==>>{:s} Train [{:}] Average Loss = {:.6f}, NME = {:.2f}'.format(time_string(), epoch_str, train_loss, train_nme*100))

    # remember best prec@1 and save checkpoint
    save_path = save_checkpoint({
          'epoch': epoch,
          'args' : deepcopy(args),
          'arch' : model_config.arch,
          'state_dict': net.state_dict(),
          'scheduler' : scheduler.state_dict(),
          'optimizer' : optimizer.state_dict(),
          }, logger.path('model') / '{:}-{:}.pth'.format(model_config.arch, epoch_str), logger)

    last_info = save_checkpoint({
          'epoch': epoch,
          'last_checkpoint': save_path,
          }, logger.last_info(), logger)

    if args.regression:
        eval_results = basic_eval_all_regression(args, eval_loaders, net, criterion, epoch_str, logger, opt_config)
    else:
        eval_results = basic_eval_all(args, eval_loaders, net, criterion, epoch_str, logger, opt_config)
    
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  logger.close()
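
Example #2 calls prepare_seed and AverageMeter, both defined elsewhere in the repository. Minimal sketches, assuming they follow the usual patterns (seed every RNG in reach; track a running average):

import random
import numpy as np
import torch

def prepare_seed(rand_seed):
    # Seed Python, NumPy, and PyTorch (CPU and all GPUs) for reproducibility.
    random.seed(rand_seed)
    np.random.seed(rand_seed)
    torch.manual_seed(rand_seed)
    torch.cuda.manual_seed_all(rand_seed)

class AverageMeter(object):
    """Tracks the latest value and a running average."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count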
Example #3
def train(args):

    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    tfboard_writer = SummaryWriter()
    logname = '{}'.format(datetime.datetime.now().strftime('%Y-%m-%d-%H:%M'))
    logger = Logger(args.save_path, logname)
    logger.log('Arguments : -------------------------------')
    for name, value in args._get_kwargs():
        logger.log('{:16} : {:}'.format(name, value))

    # Data Augmentation
    mean_fill = tuple([int(x * 255) for x in [0.485, 0.456, 0.406]])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = [
        transforms.AugTransBbox(args.transbbox_prob, args.transbbox_percent)
    ]
    train_transform += [transforms.PreCrop(args.pre_crop_expand)]
    train_transform += [
        transforms.TrainScale2WH((args.crop_width, args.crop_height))
    ]
    #train_transform += [transforms.AugHorizontalFlip(args.flip_prob)]
    #train_transform += [transforms.AugScale(args.scale_prob, args.scale_min, args.scale_max)]
    #train_transform += [transforms.AugCrop(args.crop_width, args.crop_height, args.crop_perturb_max, mean_fill)]
    if args.rotate_max:
        train_transform += [transforms.AugRotate(args.rotate_max)]
    train_transform += [
        transforms.AugGaussianBlur(args.gaussianblur_prob,
                                   args.gaussianblur_kernel_size,
                                   args.gaussianblur_sigma)
    ]
    train_transform += [transforms.ToTensor(), normalize]
    train_transform = transforms.Compose(train_transform)

    eval_transform = transforms.Compose([
        transforms.PreCrop(args.pre_crop_expand),
        transforms.TrainScale2WH((args.crop_width, args.crop_height)),
        transforms.ToTensor(), normalize
    ])

    # Training datasets
    train_data = GeneralDataset(args.num_pts, train_transform,
                                args.train_lists)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    # Evaluation Dataloader
    eval_loaders = []

    for eval_ilist in args.eval_lists:
        eval_idata = GeneralDataset(args.num_pts, eval_transform, eval_ilist)
        eval_iloader = torch.utils.data.DataLoader(eval_idata,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True)
        eval_loaders.append(eval_iloader)

    net = Model(args.num_pts)

    logger.log("=> network :\n {}".format(net))
    logger.log('arguments : {:}'.format(args))

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                       net.parameters()),
                                lr=args.LR,
                                momentum=args.momentum,
                                weight_decay=args.decay,
                                nesterov=args.nesterov)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.schedule,
                                                     gamma=args.gamma)

    criterion = wing_loss(args)
    # criterion = torch.nn.MSELoss(reduce=True)

    net = net.cuda()
    criterion = criterion.cuda()
    net = torch.nn.DataParallel(net)

    last_info = logger.last_info()
    if last_info.exists():
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch'] + 1
        checkpoint = torch.load(last_info['last_checkpoint'])
        assert last_info['epoch'] == checkpoint[
            'epoch'], 'Last-Info is not right {:} vs {:}'.format(
                last_info, checkpoint['epoch'])
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        logger.log("=> load-ok checkpoint '{:}' (epoch {:}) done".format(
            logger.last_info(), checkpoint['epoch']))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch = 0

    for epoch in range(start_epoch, args.epochs):
        scheduler.step()

        net.train()

        # train
        img_prediction = []
        img_target = []
        train_losses = AverageMeter()
        for i, (inputs, target) in enumerate(train_loader):

            target = target.squeeze(1)
            inputs = inputs.cuda()
            target = target.cuda()
            #print(inputs.size())
            #assert 1 == 0

            prediction = net(inputs)

            loss = criterion(prediction, target)
            train_losses.update(loss.item(), inputs.size(0))

            prediction = prediction.detach().to(torch.device('cpu')).numpy()
            target = target.detach().to(torch.device('cpu')).numpy()

            for idx in range(inputs.size()[0]):
                img_prediction.append(prediction[idx, :])
                img_target.append(target[idx, :])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % args.print_freq == 0 or i + 1 == len(train_loader):
                logger.log(
                    '[train Info]: [epoch-{}-{}][{:04d}/{:04d}][Loss:{:.2f}]'.
                    format(epoch, args.epochs, i, len(train_loader),
                           loss.item()))

        train_nme = compute_nme(args.num_pts, img_prediction, img_target)
        logger.log('epoch {:02d} completed!'.format(epoch))
        logger.log(
            '[train Info]: [epoch-{}-{}][Avg Loss:{:.6f}][NME:{:.2f}]'.format(
                epoch, args.epochs, train_losses.avg, train_nme * 100))
        tfboard_writer.add_scalar('Average Loss', train_losses.avg, epoch)
        tfboard_writer.add_scalar('NME', train_nme * 100,
                                  epoch)  # training-data NME

        # save checkpoint
        filename = 'epoch-{}-{}.pth'.format(epoch, args.epochs)
        save_path = logger.path('model') / filename
        torch.save(
            {
                'epoch': epoch,
                'args': deepcopy(args),
                'state_dict': net.state_dict(),
                'scheduler': scheduler.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            logger.path('model') / filename)
        logger.log('save checkpoint into {}'.format(filename))
        torch.save({
            'epoch': epoch,
            'last_checkpoint': save_path
        }, logger.last_info())

        # eval
        logger.log('Basic-Eval-All evaluates {} datasets'.format(
            len(eval_loaders)))

        for i, loader in enumerate(eval_loaders):

            eval_losses = AverageMeter()
            eval_prediction = []
            eval_target = []
            with torch.no_grad():
                net.eval()
                for i_batch, (inputs, target) in enumerate(loader):

                    target = target.squeeze(1)
                    inputs = inputs.cuda()
                    target = target.cuda()
                    prediction = net(inputs)
                    loss = criterion(prediction, target)
                    eval_losses.update(loss.item(), inputs.size(0))

                    prediction = prediction.detach().to(
                        torch.device('cpu')).numpy()
                    target = target.detach().to(torch.device('cpu')).numpy()

                    for idx in range(inputs.size()[0]):
                        eval_prediction.append(prediction[idx, :])
                        eval_target.append(target[idx, :])
                    if i_batch % args.print_freq == 0 or i_batch + 1 == len(loader):
                        logger.log(
                            '[Eval Info]: [epoch-{}-{}][{:04d}/{:04d}][Loss:{:.2f}]'
                            .format(epoch, args.epochs, i_batch, len(loader),
                                    loss.item()))

            eval_nme = compute_nme(args.num_pts, eval_prediction, eval_target)
            logger.log(
                '[Eval Info]: [evaluate the {}/{}-th dataset][epoch-{}-{}][Avg Loss:{:.6f}][NME:{:.2f}]'
                .format(i, len(eval_loaders), epoch, args.epochs,
                        eval_losses.avg, eval_nme * 100))
            tfboard_writer.add_scalar('eval_nme/{}'.format(i), eval_nme * 100,
                                      epoch)

    logger.close()
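
compute_nme above reduces the accumulated predictions/targets to a normalized mean error. A plausible minimal version, assuming flattened (x, y)-interleaved landmark vectors and inter-ocular normalization for the 68-point scheme; the real helper may normalize by bounding-box size instead:

import numpy as np

def compute_nme(num_pts, predictions, targets):
    errors = []
    for pred, gt in zip(predictions, targets):
        pred = np.asarray(pred, dtype=np.float64).reshape(num_pts, 2)
        gt = np.asarray(gt, dtype=np.float64).reshape(num_pts, 2)
        if num_pts == 68:
            # Inter-ocular distance: outer eye corners in the 68-point annotation.
            norm = np.linalg.norm(gt[36] - gt[45])
        else:
            # Fall back to the ground-truth bounding-box diagonal.
            norm = np.linalg.norm(gt.max(axis=0) - gt.min(axis=0))
        errors.append(np.linalg.norm(pred - gt, axis=1).mean() / norm)
    return float(np.mean(errors))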
Example #4
if not os.path.exists(experiment_path):
    os.makedirs(experiment_path)

if args.run:
    # Load params
    params = pickle.load(open(model_path + '/params.p', 'rb'))

    # Load model
    state_dict = torch.load('{}/model/{}_state_dict_best.pth'.format(
        model_path, params['model']),
                            map_location=lambda storage, loc: storage)
    model = load_model(params['model'], params)
    model.load_state_dict(state_dict)

    # Load ground-truth states from test set
    test_loader = DataLoader(GeneralDataset(params['dataset'],
                                            train=False,
                                            normalize_data=params['normalize'],
                                            subsample=params['subsample']),
                             batch_size=args.n_samples,
                             shuffle=args.shuffle)
    data, macro_intents = next(iter(test_loader))
    data, macro_intents = data.transpose(0, 1), macro_intents.transpose(0, 1)

    # Sample trajectories
    samples, macro_samples = model.sample(data,
                                          macro_intents,
                                          burn_in=args.burn_in)

    # Save samples
    samples = samples.detach().numpy()
    pickle.dump(samples,
                open(experiment_path + '/samples.p', 'wb'))
else:
    printlog('{:03d} {} {}'.format(args.trial, args.model, args.dataset))
    printlog(model.params_str)
    printlog(
        'start_lr {} | min_lr {} | subsample {} | batch_size {} | seed {}'.
        format(args.start_lr, args.min_lr, args.subsample, args.batch_size,
               args.seed))
    printlog('n_params: {:,}'.format(params['total_params']))
    printlog('best_loss:')
printlog('############################################################')

# Dataset loaders
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = DataLoader(GeneralDataset(args.dataset,
                                         train=True,
                                         normalize_data=args.normalize,
                                         subsample=args.subsample),
                          batch_size=batch_size,
                          shuffle=True,
                          **kwargs)
test_loader = DataLoader(GeneralDataset(args.dataset,
                                        train=False,
                                        normalize_data=args.normalize,
                                        subsample=args.subsample),
                         batch_size=batch_size,
                         shuffle=True,
                         **kwargs)

############################# TRAIN LOOP #############################

best_test_loss = 0
# Create save destination
save_path = 'datasets/{}/data/examples'.format(args.dataset)
if not os.path.exists(save_path):
    os.makedirs(save_path)

# Set params
params = {
    'dataset' : args.dataset,
    'normalize' : True,
    'n_samples' : args.n_samples,
    'burn_in' : 0,
    'genMacro' : True
}   

# Load ground-truth states from test set
test_loader = DataLoader(
    GeneralDataset(params['dataset'], train=False, normalize_data=params['normalize'], subsample=1), 
    batch_size=args.n_samples, shuffle=args.shuffle)
data, macro_intents = next(iter(test_loader))
data, macro_intents = data.detach().numpy(), macro_intents.detach().numpy()

# Get dataset plot function
dataset = import_module('datasets.{}'.format(params['dataset']))
plot_func = dataset.animate if args.animate else dataset.display

for k in range(args.n_samples):
    print('Sample {:02d}'.format(k))
    save_file = '{}/{:02d}'.format(save_path, k)
    plot_func(data[k], macro_intents[k], params=params, save_file=save_file)
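
Example #4 logs through printlog, which is also not defined in the snippet. A minimal version consistent with how it is called above; the log-file location is a hypothetical choice:

import os

def printlog(line, log_dir='.'):
    # Print to stdout and append to a plain-text log (assumed behaviour).
    print(line)
    with open(os.path.join(log_dir, 'log.txt'), 'a') as f:
        f.write('{}\n'.format(line))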