def main(args):
    """Parse command-line *args* and run the model in ``train`` or ``test`` mode.

    Modes:
      * ``train`` -- build an ``md.CtxBB`` model and a data loader (a fake
        loader when ``--fake`` is given, otherwise the ``vcoco_train`` split),
        wrap both in ``md.BasicTrainer`` and train for ``--epochs`` epochs.
      * ``test`` -- ``torch.load`` the checkpoint named by ``--net``, take its
        ``"model"`` entry and hand it to ``ev.do_eval`` on the ``vcoco_val``
        split, passing ``--save_dir`` through.

    Args:
        args: list of command-line argument strings (e.g. ``sys.argv[1:]``).

    Raises:
        SystemExit: status 1 for an unknown mode; status 2 (raised by
            ``argparse``) for malformed arguments or for ``test`` mode
            without ``--net``.
    """
    logger = logging.getLogger(__name__)

    parser = argparse.ArgumentParser(
        description=("Run deep models for visual semantic role segmentation "
                     "(or detection)"))
    parser.add_argument("mode", help="Mode to run model in (e.g. 'train')")
    # Timestamped default so that separate runs get separate directories.
    run_stamp = dt.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    parser.add_argument(
        "-s", "--save_dir", help="directory for saving the model",
        default="saved_models/%s" % run_stamp)
    parser.add_argument(
        "-e", "--epochs", help="number of epochs for training",
        type=int, default=50)
    parser.add_argument(
        "-p", "--save_per", help="epochs to wait before saving",
        type=int, default=5)
    parser.add_argument(
        "-l", "--learn_rate", help="learning rate",
        type=float, default=0.001)
    parser.add_argument(
        "-c", "--cuda", type=int, nargs="+",
        help="ids of gpus to use during training", default=[])
    parser.add_argument(
        "-f", "--fake", action="store_true",
        # The trailing space is required: adjacent literals are concatenated.
        help=("flag to use fake data that loads quickly (for "
              "development purposes)"))
    parser.add_argument(
        "--net", help="file in which model is stored. Used in test mode.",
        default=None)
    cfg = parser.parse_args(args)

    if cfg.mode == 'train':
        model = md.CtxBB()
        if cfg.fake:
            dataloader = get_fake_loader()
        else:
            dataloader = ld.get_loader("vcoco_train", ld.COCO_IMGDIR)
        # Every parsed option (including mode/net/fake) is forwarded, so
        # BasicTrainer has to accept (or absorb via **kwargs) those keywords.
        trainer = md.BasicTrainer(model, dataloader, **vars(cfg))
        logger.info("Beginning Training...")
        trainer.train(cfg.epochs)
    elif cfg.mode == 'test':
        if cfg.net is None:
            # Report a usage error rather than failing inside torch.load(None).
            parser.error("--net is required in test mode")
        checkpoint = torch.load(cfg.net)
        model = checkpoint["model"]
        evaluator = ev.Evaluator(**vars(cfg))
        ev.do_eval(evaluator, model, "vcoco_val", cfg.save_dir)
    else:
        logger.error("Invalid mode '%s'", cfg.mode)
        sys.exit(1)
def test_pickle(self):
    """A BasicTrainer configured with cuda=[0] survives torch.save + torch.load."""
    self.trainer = model.BasicTrainer(
        self.model, self.dataloader,
        save_dir=os.path.join(self.test_dir, "test_pickle"),
        save_per=4, cuda=[0])
    outname = os.path.join(self.test_dir, "test_pickle.trn")
    torch.save(self.trainer, outname)
    # Saving alone does not prove the object can be restored; load it back.
    loaded = torch.load(outname)
    self.assertIsInstance(loaded, model.BasicTrainer)
def test_prod_train(self):
    """Train a ``model.TestCtxBB`` for one epoch with cuda=[0]."""
    self.model = model.TestCtxBB()
    self.trainer = model.BasicTrainer(
        self.model, self.dataloader,
        save_dir=os.path.join(self.test_dir, "test_prod_train"),
        save_per=4, cuda=[0])
    self.trainer.train(1)
    # One train(1) call on a fresh trainer should advance the epoch counter to 1.
    self.assertEqual(self.trainer.epoch, 1)
def test_save(self):
    """The epoch counter of a trained trainer persists through torch.save/torch.load."""
    checkpoint_path = os.path.join(self.test_dir, "resume_train.trn")
    self.trainer = model.BasicTrainer(
        self.model,
        self.dataloader,
        cuda=None,
        save_dir=os.path.join(self.test_dir, "test_save"),
    )
    # Two epochs give the trainer non-trivial state that must be serialized.
    self.trainer.train(2)
    torch.save(self.trainer, checkpoint_path)
    restored = torch.load(checkpoint_path)
    self.assertEqual(restored.epoch, 2)
def test_resume_train(self):
    """The epoch counter accumulates across repeated (even zero-length) train() calls."""
    self.trainer = model.BasicTrainer(
        self.model,
        self.dataloader,
        cuda=None,
        save_dir=os.path.join(self.test_dir, "test_resume"),
        save_per=4,
    )
    # Each pair: (epochs to train in this call, expected cumulative epoch after it).
    schedule = ((2, 2), (6, 8), (0, 8), (3, 11))
    for n_epochs, expected_epoch in schedule:
        self.trainer.train(n_epochs)
        self.assertEqual(self.trainer.epoch, expected_epoch)