def main():
    global args
    args = parser.parse_args()
    os.makedirs(args.output, exist_ok=True)

    # if torch.cuda.current_device() isn't called here, training fails later with
    # "RuntimeError: cuda runtime error (30) : unknown error at ..\aten\src\THC\THCGeneral.cpp:87"
    torch.cuda.current_device()
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # try to get consistent results across runs
    # => currently still fails, however, makes runs a bit more consistent
    _set_random_seed()

    # create model
    model = PoseNet(arch=args.arch, num_features=args.features, dropout=args.dropout,
                    pretrained=True, cache_dir=args.cache, loss=args.loss,
                    excl_bn_affine=args.excl_bn, beta=args.beta, sx=args.sx, sq=args.sq)

    # create optimizer
    # - currently only Adam is supported
    if args.optimizer == 'adam':
        eps = 0.1
        if args.split_opt_params:
            # biases get double the learning rate and no weight decay;
            # 'others' are kept frozen (lr=0)
            new_biases, new_weights, biases, weights, others = \
                model.params_to_optimize(split=True, excl_batch_norm=args.excl_bn)
            optimizer = torch.optim.Adam([
                {'params': new_biases, 'lr': args.lr * 2, 'weight_decay': 0.0, 'eps': eps},
                {'params': new_weights, 'lr': args.lr, 'weight_decay': args.weight_decay, 'eps': eps},
                {'params': biases, 'lr': args.lr * 2, 'weight_decay': 0.0, 'eps': eps},
                {'params': weights, 'lr': args.lr, 'weight_decay': args.weight_decay, 'eps': eps},
                {'params': others, 'lr': 0, 'weight_decay': 0, 'eps': eps},
            ])
        else:
            params = model.params_to_optimize(excl_batch_norm=args.excl_bn)
            optimizer = torch.optim.Adam(params, lr=args.lr,
                                         weight_decay=args.weight_decay, eps=eps)
    else:
        assert False, 'Invalid optimizer: %s' % args.optimizer

    # optionally resume from a checkpoint
    best_loss = float('inf')
    best_epoch = -1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_epoch = checkpoint['best_epoch']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            quit()

    # define overall training dataset, set output normalization, load model to gpu
    all_tr_data = PoseDataset(args.data, 'dataset_train.txt', random_crop=not args.center_crop)
    model.set_target_transform(all_tr_data.target_mean, all_tr_data.target_std)
    model.to(device)

    # split overall training data into training and validation sets;
    # the validation set is used for early stopping, or possibly in the future
    # for hyperparameter optimization
    lengths = [round(len(all_tr_data) * 0.75), round(len(all_tr_data) * 0.25)]
    tr_data, val_data = torch.utils.data.random_split(all_tr_data, lengths)

    # define data loaders
    train_loader = DataLoader(tr_data, batch_size=args.batch_size, num_workers=args.workers,
                              shuffle=True, pin_memory=True, worker_init_fn=_worker_init_fn)
    val_loader = DataLoader(val_data, batch_size=args.batch_size, num_workers=args.workers,
                            shuffle=False, pin_memory=True, worker_init_fn=_worker_init_fn)
    test_loader = DataLoader(PoseDataset(args.data, 'dataset_test.txt', random_crop=False),
                             batch_size=args.batch_size, num_workers=args.workers,
                             shuffle=False, pin_memory=True, worker_init_fn=_worker_init_fn)

    # evaluate model only
    if args.evaluate:
        validate(test_loader, model, device)
        return

    # training loop
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        lss, pos, ori = process(train_loader, model, optimizer, epoch, device,
                                adv_tr_eps=args.adv_tr_eps)
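        # layout of the 16-column log row assembled below (inferred from the
        # slice assignments; each metric group is: loss average, then average
        # and median position error, then average and median orientation error)
        #   0-5:   epoch + training metrics
        #   6-10:  validation metrics (filled every args.test_freq epochs)
        #   11-15: test metrics (filled only on a new best validation loss)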
        stats = np.zeros(16)
        stats[:6] = [epoch, lss.avg, pos.avg, pos.median, ori.avg, ori.median]

        # evaluate on validation set
        if (epoch+1) % args.test_freq == 0:
            lss, pos, ori = validate(val_loader, model, device)
            stats[6:11] = [lss.avg, pos.avg, pos.median, ori.avg, ori.median]

            # remember best loss and save checkpoint
            is_best = lss.avg < best_loss
            best_epoch = epoch if is_best else best_epoch
            best_loss = lss.avg if is_best else best_loss

            # save best model
            if is_best:
                _save_checkpoint({
                    'epoch': epoch + 1,
                    'best_epoch': best_epoch,
                    'best_loss': best_loss,
                    'arch': args.arch,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, True)
        else:
            is_best = False

        # maybe save a checkpoint even if not the best model
        if (epoch+1) % args.save_freq == 0 and not is_best:
            _save_checkpoint({
                'epoch': epoch + 1,
                'best_epoch': best_epoch,
                'best_loss': best_loss,
                'arch': args.arch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, False)

        # evaluate on test set if best result yet on validation set
        if is_best:
            lss, pos, ori = validate(test_loader, model, device)
            stats[11:] = [lss.avg, pos.avg, pos.median, ori.avg, ori.median]

        # add row to log file
        _save_log(stats, epoch == 0)

        # early stopping
        if args.early_stopping > 0 and epoch - best_epoch >= args.early_stopping:
            print('=====\nEARLY STOPPING CRITERION MET (%d epochs since best validation loss)'
                  % args.early_stopping)
            break

        print('=====\n')

    if epoch+1 == args.epochs:
        print('MAX EPOCHS (%d) REACHED' % args.epochs)
    print('BEST VALIDATION LOSS: %.3f' % best_loss)
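# The seeding helpers used in main() are defined elsewhere in this file; what
# follows is only a minimal sketch of what they might look like (the names come
# from the source above, but these bodies are assumptions; `random` would also
# need to be imported at the top of the file):

def _set_random_seed(seed=0):
    # seed every RNG in play; full GPU determinism would additionally need
    # torch.backends.cudnn.deterministic = True, at some speed cost
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _worker_init_fn(worker_id):
    # re-seed numpy in each DataLoader worker so random augmentations differ
    # between workers instead of repeating the parent process stream
    np.random.seed((torch.initial_seed() + worker_id) % 2**32)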
print('Train size: {} x {}'.format(len(train_data), train_data[0].size()))

## LOAD MODEL
print('\nLOADING GAN.')

def weights_init(m):
    if type(m) == torch.nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        torch.nn.init.constant_(m.bias, 0.0)

netG = PoseNet(n_hidden=N_HIDDEN, mode='generator').to(device)
netD = PoseNet(n_hidden=N_HIDDEN, mode='discriminator').to(device)
if args.model:
    # load both networks from a single checkpoint file
    checkpoint = torch.load(args.model)
    netG.load_state_dict(checkpoint['netG'])
    netD.load_state_dict(checkpoint['netD'])
    print('=> Loaded models from {:s}'.format(args.model))
else:
    netG.apply(weights_init)
    netD.apply(weights_init)
print('Model params: {:.2f}M'.format(
    sum(p.numel() for p in netG.parameters()) / 1e6))

## TRAINING
print('\nTRAINING.')
data_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
                                          shuffle=True, **kwargs)
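# Illustration only: the load branch above expects one checkpoint dict keyed by
# 'netG' and 'netD', so the matching save side presumably looks something like
# the sketch below (the helper name and its path argument are assumptions, not
# this script's actual code):

def save_gan_checkpoint(path):
    # persist both networks in the format that the load branch above expects
    torch.save({'netG': netG.state_dict(), 'netD': netD.state_dict()}, path)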