def main():
    """Timing entry point: load a trained checkpoint and measure average
    per-example inference time over the test split (repeated 16x)."""
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Auto-generate a run directory name from the timestamp when none is given;
    # ':' and '.' are replaced so the name is filesystem-safe.
    if args.savemodel == "":
        timestr = datetime.now().isoformat().replace(':', '-').replace('.', 'MS')
        args.savemodel = timestr
    savepath = os.path.join(args.savedir, args.savemodel)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(savepath, exist_ok=True)
    log_file = os.path.join(savepath, 'run.log')

    logger = logging.getLogger('FS')
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(filename)s - %(lineno)s: %(message)s')
    # Log to both stderr and a per-run file (distinct handler names: the
    # original reused one variable for both, which obscured intent).
    stream_handler = logging.StreamHandler(sys.stderr)
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Record the full configuration for reproducibility.
    for k, v in sorted(vars(args).items()):
        logger.info('%s - %s', k, v)

    # seed == 0 means "do not seed" (non-deterministic run).
    if args.seed != 0:
        torch.manual_seed(args.seed)
        if args.cuda:
            torch.cuda.manual_seed(args.seed)

    if args.dataset == "flow":
        trli, trri, trld, teli, teri, teld = lt.list_flow_file(args.datapath)
    elif args.dataset == "kitti":
        trli, trri, trld, teli, teri, teld = lt.list_kitti_file(
            args.datapath, args.date)
    elif args.dataset == "middlebury":
        trli, trri, trld, teli, teri, teld = lt.list_middlebury_file(
            args.datapath)
    else:
        # Fail fast with a clear message instead of a NameError on trli below.
        raise ValueError('unknown dataset: %s' % args.dataset)

    # Repeat the test split 16x so warm-up noise is amortized in the timing.
    TimeImgLoader = torch.utils.data.DataLoader(
        DA.ImageFloder(teli * 16, teri * 16, teld * 16,
                       training=False, args=args),
        batch_size=args.batch_size, shuffle=False, num_workers=5,
        drop_last=False)

    model = get_model(args)
    if args.cuda:
        model.cuda()

    if args.loadmodel is not None:
        # map_location='cpu' lets a GPU-trained checkpoint load on a
        # CPU-only host; model.cuda() above already moved params if needed.
        state_dict = torch.load(
            os.path.join(args.savedir, args.loadmodel, "max_loss.tar"),
            map_location='cpu')
        model.load_state_dict(state_dict['state_dict'])
    logger.info('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    ## Timing ##
    start_time = time.time()
    total_time = 0.
    for batch_idx, (imgL, imgR, disp_L) in enumerate(TimeImgLoader):
        # runtime() is assumed to return the measured seconds for this
        # batch — TODO confirm against its definition.
        per_time = runtime(model, args, imgL, imgR, disp_L)
        total_time += per_time
    logger.info(
        'total test time = %.5f, per example time = %.5f',
        time.time() - start_time, total_time / len(TimeImgLoader.dataset))
def main():
    """Training entry point: train the stereo model, evaluate every epoch,
    and keep the checkpoint with the best (lowest) test metric."""
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Auto-generate a run directory name from the timestamp when none is given.
    if args.savemodel == "":
        timestr = datetime.now().isoformat().replace(':', '-').replace('.', 'MS')
        args.savemodel = timestr
    savepath = os.path.join(args.savedir, args.savemodel)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(savepath, exist_ok=True)
    log_file = os.path.join(savepath, 'run.log')

    logger = logging.getLogger('FS')
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(filename)s - %(lineno)s: %(message)s')
    # Log to both stderr and a per-run file.
    stream_handler = logging.StreamHandler(sys.stderr)
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Record the full configuration for reproducibility.
    for k, v in sorted(vars(args).items()):
        logger.info('%s - %s', k, v)

    # Optional one-hour stall before starting — presumably to wait for a
    # shared GPU to free up; NOTE(review): confirm intent with the author.
    if args.time_out:
        logger.info('start sleeping')
        time.sleep(3600)
        logger.info('end sleeping')

    # seed == 0 means "do not seed" (non-deterministic run).
    if args.seed != 0:
        torch.manual_seed(args.seed)
        if args.cuda:
            torch.cuda.manual_seed(args.seed)

    if args.dataset == "flow":
        trli, trri, trld, teli, teri, teld = lt.list_flow_file(args.datapath)
    elif args.dataset == "kitti":
        if args.all_train:
            # Merge KITTI 2015 and KITTI 2012 splits into one dataset.
            trli, trri, trld, teli, teri, teld = lt.list_kitti_file(
                os.path.join(args.datapath, 'kitti2015', 'training'), '2015')
            trli2, trri2, trld2, teli2, teri2, teld2 = lt.list_kitti_file(
                os.path.join(args.datapath, 'kitti2012', 'train_194'), '2012')
            trli.extend(trli2)
            trri.extend(trri2)
            trld.extend(trld2)
            teli.extend(teli2)
            teri.extend(teri2)
            teld.extend(teld2)
        else:
            trli, trri, trld, teli, teri, teld = lt.list_kitti_file(
                args.datapath, args.date)
    elif args.dataset == "middlebury":
        trli, trri, trld, teli, teri, teld = lt.list_middlebury_file(
            args.datapath)
    else:
        # Fail fast with a clear message instead of a NameError on trli below.
        raise ValueError('unknown dataset: %s' % args.dataset)

    # all_train additionally folds the test split into the training set
    # (the test loader still evaluates on it below).
    if args.all_train:
        trli.extend(teli)
        trri.extend(teri)
        trld.extend(teld)

    # NOTE(review): training=args.no_train_aug reads inverted (one would
    # expect `not args.no_train_aug`) — preserved as-is; confirm upstream.
    TrainImgLoader = torch.utils.data.DataLoader(
        DA.ImageFloder(trli, trri, trld, training=args.no_train_aug, args=args),
        batch_size=args.batch_size, shuffle=True, num_workers=5,
        drop_last=True)
    TestImgLoader = torch.utils.data.DataLoader(
        DA.ImageFloder(teli, teri, teld, training=False, args=args),
        batch_size=args.batch_size, shuffle=False, num_workers=5,
        drop_last=False)

    model = get_model(args)
    if args.cuda:
        model = nn.DataParallel(model)
        model.cuda()

    if args.loadmodel is not None:
        # Checkpoint was saved from a DataParallel-wrapped model in this
        # script, so keys line up with the wrapped model here.
        state_dict = torch.load(
            os.path.join(args.savedir, args.loadmodel, "max_loss.tar"))
        model.load_state_dict(state_dict['state_dict'])
    logger.info('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.learning_rate,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'mom':
        optimizer = optim.SGD(model.parameters(), lr=args.learning_rate,
                              momentum=0.9, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.learning_rate,
                              weight_decay=args.weight_decay)
    else:
        # Previously an unknown optimizer produced a NameError further down.
        raise ValueError('unknown optimizer: %s' % args.optimizer)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.decay_epochs)

    # Best-so-far tracking; the saved file keeps its historical name
    # 'max_loss.tar' so existing tooling (and the timing script) still works.
    best_loss = 1e10
    best_epoch = 0
    loss_avg = 0.
    for epoch in range(1, args.epochs + 1):
        logger.info('This is %d-th epoch', epoch)
        total_train_loss = 0

        ## training ##
        start_time = time.time()
        for batch_idx, (imgL_crop, imgR_crop, disp_crop_L) in enumerate(
                TrainImgLoader):
            loss = train(model, optimizer, args, imgL_crop, imgR_crop,
                         disp_crop_L, epoch=epoch)
            # Exponential moving average for smoother iteration logging.
            loss_avg = 0.99 * loss_avg + 0.01 * loss
            if (batch_idx + 1) % args.log_steps == 0:
                logger.info('Iter %d training loss = %.3f , time = %.2f',
                            batch_idx + 1, loss_avg,
                            time.time() - start_time)
                start_time = time.time()
            total_train_loss += loss
        logger.info('epoch %d total training loss = %.3f',
                    epoch, total_train_loss / len(TrainImgLoader))
        # Step the LR scheduler AFTER this epoch's optimizer updates — the
        # required ordering since PyTorch 1.1; stepping at the top of the
        # epoch (as before) skips the initial learning rate and decays one
        # epoch early.
        scheduler.step()

        ## TEST ##
        start_time = time.time()
        total_test_loss = 0.
        total_test_num = 0.
        for batch_idx, (imgL, imgR, disp_L) in enumerate(TestImgLoader):
            loss, num = test(model, args, imgL, imgR, disp_L)
            if (batch_idx + 1) % args.log_steps == 0:
                logger.info('Iter %d test loss = %.5f , time = %.2f',
                            batch_idx + 1, loss / num,
                            time.time() - start_time)
                start_time = time.time()
            total_test_loss += loss
            total_test_num += num
        if args.dataset == "kitti" or args.dataset == "middlebury":
            # For these datasets test() apparently accumulates correct-pixel
            # counts, so convert to an error percentage — TODO confirm
            # against test()'s definition.
            total_test_loss = (1 - total_test_loss / total_test_num) * 100.
        else:
            total_test_loss = total_test_loss / total_test_num
        logger.info('total test loss = %.5f', total_test_loss)

        # Lower is better for both branches above; snapshot the best model.
        if total_test_loss < best_loss:
            best_loss = total_test_loss
            best_epoch = epoch
            savefilename = os.path.join(savepath, 'max_loss.tar')
            torch.save({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'test_loss': total_test_loss,
            }, savefilename)
        logger.info('MAX epoch %d total test error = %.5f',
                    best_epoch, best_loss)