def fit(self, learner: Learner, weight_decay):
    """Train ``learner`` with ``fit_one_cycle``, optionally adding early stopping.

    If ``self.early_stop`` is truthy, an ``EarlyStoppingCallback`` using its
    ``patience`` is appended to ``learner.callbacks`` before training starts.

    Args:
        learner: The learner to train; its ``callbacks`` list may be mutated.
        weight_decay: Passed to ``fit_one_cycle`` as ``wd``.
    """
    stop_cfg = self.early_stop
    if stop_cfg:
        stopper = EarlyStoppingCallback(learner, patience=stop_cfg.patience)
        learner.callbacks.append(stopper)

    schedule = dict(
        cyc_len=self.cyc_len,
        tot_epochs=self.max_epochs,
        max_lr=self.max_lr,
        wd=weight_decay,
    )
    fit_one_cycle(learner, **schedule)
def _strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with leading ``prefix`` removed from every key.

    Weights saved from a ``DataParallel``/``DistributedDataParallel``-wrapped
    model carry a ``'module.'`` key prefix. Only the *leading* prefix is
    removed; a plain ``str.replace`` would also mangle keys that merely contain
    ``'module.'`` (e.g. ``'submodule.w'`` -> ``'subw'``).
    """
    stripped = {}
    for key, value in state_dict.items():
        # Loop so that doubly wrapped checkpoints ('module.module.x') are handled too.
        while key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped


def _load_checkpoint_weights(model, path):
    """Load weights from the checkpoint file at ``path`` into ``model``.

    The checkpoint is loaded onto the CPU. The weights must be stored under
    a ``'model'`` key or, failing that, a ``'state_dict'`` key.

    Raises:
        RuntimeError: If the checkpoint has neither key.
    """
    ckpt = torch.load(path, 'cpu')
    if 'model' in ckpt:
        pretrained = ckpt['model']
    elif 'state_dict' in ckpt:
        pretrained = ckpt['state_dict']
    else:
        raise RuntimeError(
            "checkpoint {!r} has neither a 'model' nor a 'state_dict' entry "
            "(found keys: {})".format(path, list(ckpt)))
    model.load_state_dict(_strip_module_prefix(pretrained))


def main():
    """Build PSMNet, optionally load weights, and train on KITTI ROI data.

    All configuration comes from the module-level ``args``. ``args.load_model``
    is a raw checkpoint loaded into the model before the learner is built.
    ``args.load`` is loaded through ``learner.load`` afterwards, so it
    overrides ``args.load_model`` when both are given (a warning is emitted).

    Raises:
        ValueError: If ``args.mode`` is not ``'train'`` or ``'train_oc'``.
    """
    # Fail fast on a bad mode before allocating GPU memory and data loaders.
    if args.mode not in ('train', 'train_oc'):
        raise ValueError('args.mode not supported.')

    model = PSMNet(args.maxdisp, args.mindisp).cuda()
    if args.load_model is not None:
        if args.load is not None:
            warn('args.load is not None. load_model will be covered by load.')
        _load_checkpoint_weights(model, args.load_model)

    train_dl = DataLoader(
        KITTIRoiDataset(args.data_dir, 'train', args.resolution,
                        args.maxdisp, args.mindisp),
        batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
    val_dl = DataLoader(
        KITTIRoiDataset(args.data_dir, 'val', args.resolution,
                        args.maxdisp, args.mindisp),
        batch_size=args.batch_size, num_workers=args.workers)

    loss_fn = PSMLoss()
    databunch = DataBunch(train_dl, val_dl, device='cuda')
    learner = Learner(databunch, model, loss_func=loss_fn,
                      model_dir=args.model_dir)
    learner.callbacks = [
        DistributedSaveModelCallback(learner),
        TensorBoardCallback(learner),
    ]
    if num_gpus > 1:
        learner.to_distributed(get_rank())
    if args.load is not None:
        learner.load(args.load)

    if args.mode == 'train':
        learner.fit(args.epochs, args.maxlr)
    else:  # 'train_oc' (validated above)
        fit_one_cycle(learner, args.epochs, args.maxlr)
def _stegnet_train(model):
    """Train ``model`` on ``<datapath>/train`` and ``<datapath>/val``, then save it.

    If ``args.model`` is set, its weights are loaded first. The trained
    weights are written to ``model.pth`` in the working directory.
    """
    data_train = ImageLoader(args.datapath + '/train', args.num_train,
                             args.fourierSeed, args.size, args.bs)
    data_val = ImageLoader(args.datapath + '/val', args.num_val,
                           args.fourierSeed, args.size, args.bs)
    data = DataBunch(data_train, data_val)
    print("Loaded DataSets")
    if args.model is not None:
        model.load_state_dict(torch.load(args.model))
        print("Loaded pretrained model")
    learn = Learner(data, model, loss_func=mse, metrics=[mse_cov, mse_hidden])
    print("training")
    fit_one_cycle(learn, args.epochs, 1e-2)
    torch.save(learn.model.state_dict(), "model.pth")
    print("model saved")


def _stegnet_process_dir(src_name, dst_name, codec, network, cipher):
    """Apply ``codec`` to every file in ``<datapath>/<src_name>`` in parallel.

    The ``<datapath>/<dst_name>`` directory is created if it is missing.
    ``codec`` receives ``model=network``, ``size=args.size``, and a
    ``fourier_func`` that is ``cipher`` bound to ``args.fourierSeed``.
    """
    src_dir = args.datapath + '/' + src_name
    f_paths = [src_dir + '/' + f for f in os.listdir(src_dir)]
    # Create the output directory if missing; real failures (e.g. permissions)
    # now surface here instead of being swallowed.
    os.makedirs(args.datapath + '/' + dst_name, exist_ok=True)
    fourier_func = partial(cipher, seed=args.fourierSeed)
    job = partial(codec, model=network, size=args.size,
                  fourier_func=fourier_func)
    parallel(job, f_paths)


def main():
    """Train StegNet, or run its encoder/decoder over a directory of images.

    With ``args.train`` the model is trained and saved. Otherwise weights are
    loaded from ``args.model``, or from a path the user is prompted for when
    ``args.model`` is unset. Then ``<datapath>/cover`` is encoded into
    ``<datapath>/encoded`` (``args.encode``), or ``<datapath>/encoded`` is
    decoded into ``<datapath>/decoded``.
    """
    model = StegNet(10, 6)
    print("Created Model")
    if args.train:
        _stegnet_train(model)
        return

    path = input(
        "Enter path of the model: ") if args.model is None else args.model
    # Bug fix: load from ``path``. The original ignored the prompted path
    # and called torch.load(args.model), i.e. torch.load(None).
    model.load_state_dict(torch.load(path))
    model.eval()
    if args.encode:
        _stegnet_process_dir('cover', 'encoded', encode, model.encoder, encrypt)
    else:
        _stegnet_process_dir('encoded', 'decoded', decode, model.decoder, decrypt)