def _extract_pretrained_weights(ckpt):
    """Return the weight mapping stored in ``ckpt`` with wrapper prefixes removed.

    The weights are looked up under the ``'model'`` key first and under
    ``'state_dict'`` second. Every leading ``'module.'`` prefix is stripped
    from the parameter names (PyTorch's DataParallel-style wrappers expose the
    wrapped network as ``.module``, which prefixes saved parameter names).
    Only *leading* prefixes are removed, so a name that merely contains the
    text ``'module.'`` (e.g. ``'submodule.weight'``) is left intact.

    Raises:
        RuntimeError: if ``ckpt`` has neither a ``'model'`` nor a
            ``'state_dict'`` entry.
    """
    for container_key in ('model', 'state_dict'):
        if container_key in ckpt:
            pretrained = ckpt[container_key]
            break
    else:
        raise RuntimeError(
            "Unsupported checkpoint layout: expected a 'model' or "
            "'state_dict' entry, found keys {}".format(list(ckpt)))

    prefix = 'module.'
    cleaned = {}
    for name, tensor in pretrained.items():
        # Loop so that doubly wrapped names ('module.module.x') are handled too.
        while name.startswith(prefix):
            name = name[len(prefix):]
        cleaned[name] = tensor
    return cleaned


def main():
    """Build the PSMNet model, optionally restore weights, and run training.

    Reads its configuration from the module-level ``args`` and ``num_gpus``:

    * ``args.load_model`` -- optional checkpoint whose weights are loaded
      into the freshly built model before the learner is created.
    * ``args.load`` -- optional name passed to ``learner.load`` after the
      learner is set up (a warning is emitted when both options are given).
    * ``args.mode`` -- ``'train'`` calls ``learner.fit``; ``'train_oc'``
      calls ``fit_one_cycle``.

    Raises:
        RuntimeError: if the ``args.load_model`` checkpoint has neither a
            ``'model'`` nor a ``'state_dict'`` entry.
        ValueError: if ``args.mode`` is not one of the supported modes.
    """
    model = PSMNet(args.maxdisp, args.mindisp).cuda()

    if args.load_model is not None:
        if args.load is not None:
            # learner.load(args.load) runs further down, after these weights
            # have been applied -- presumably overriding them.
            warn('args.load is not None. load_model will be covered by load.')
        # Deserialize onto the CPU; load_state_dict copies the values into
        # the model that was already moved to CUDA above.
        ckpt = torch.load(args.load_model, map_location='cpu')
        model.load_state_dict(_extract_pretrained_weights(ckpt))

    train_ds = KITTIRoiDataset(args.data_dir, 'train', args.resolution,
                               args.maxdisp, args.mindisp)
    val_ds = KITTIRoiDataset(args.data_dir, 'val', args.resolution,
                             args.maxdisp, args.mindisp)
    # Only the training loader shuffles.
    train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                          num_workers=args.workers)
    val_dl = DataLoader(val_ds, batch_size=args.batch_size,
                        num_workers=args.workers)

    loss_fn = PSMLoss()
    databunch = DataBunch(train_dl, val_dl, device='cuda')
    learner = Learner(databunch, model, loss_func=loss_fn,
                      model_dir=args.model_dir)
    learner.callbacks = [
        DistributedSaveModelCallback(learner),
        TensorBoardCallback(learner),
    ]

    if num_gpus > 1:
        learner.to_distributed(get_rank())
    if args.load is not None:
        learner.load(args.load)

    if args.mode == 'train':
        learner.fit(args.epochs, args.maxlr)
    elif args.mode == 'train_oc':
        fit_one_cycle(learner, args.epochs, args.maxlr)
    else:
        raise ValueError('args.mode not supported.')
lr=(args.lr, args.lr / args.lr_div), mom=(args.mom, 0.95), cycle_len=args.cycle_len, cycle_mult=args.cycle_mult, start_epoch=args.start_epoch) ] learn = Learner(db, model, metrics=[rmse, mae], callback_fns=callback_fns, wd=args.wd, loss_func=contribs_rmse_loss) if args.start_epoch > 0: learn.load(model_se_str + f'_{args.start_epoch-1}') else: learn.load(model_str) torch.cuda.empty_cache() if distributed_train: learn = learn.to_distributed(args.local_rank) learn.fit(args.epochs) # make predictions n_val = len(train_df[train_df['molecule_id'].isin(val_mol_ids)]) val_preds = np.zeros((n_val, args.epochs)) test_preds = np.zeros((len(test_df), args.epochs)) for m in range(args.epochs): print(f'Predicting for model {m}') learn.load(model_se_str + f'_{m}') val_contrib_preds = learn.get_preds(DatasetType.Valid) test_contrib_preds = learn.get_preds(DatasetType.Test) val_preds[:, m] = val_contrib_preds[0][:, -1].detach().numpy() test_preds[:, m] = test_contrib_preds[0][:, -1].detach().numpy() val_preds = val_preds * C.SC_STD + C.SC_MEAN