def train_sfpn_small_balance(depth=50, roiSize=224):
    """Train an SFPN model on small patches with a balanced training split.

    Builds the train/val/test ``SmallPatchDataset`` splits at ``roiSize``
    (only the training split is requested with ``balance=True``), wraps an
    ``SFPN`` instance in ``nn.DataParallel`` on the GPU, and delegates the
    training loop to ``train``.

    Args:
        depth: Used only to label the run in the model name.
        roiSize: Patch size given to every dataset split; also part of the
            model name.
    """
    train_ds, val_ds, test_ds = (
        sdataset.SmallPatchDataset("train", roiSize, balance=True),
        sdataset.SmallPatchDataset("val", roiSize),
        sdataset.SmallPatchDataset("test", roiSize),
    )

    # Local import: the sfpn module is only needed when this experiment runs.
    from sfpn import SFPN

    # NOTE(review): SFPN() is constructed with its defaults, so `depth`
    # changes only the run name, not the network -- confirm this is intended.
    net = SFPN().cuda()
    model = nn.DataParallel(net)

    loader_cfg = dict(bsize=192, nworker=20, collate=default_collate)
    name_args = (roiSize, depth)
    run_name = "sp%d_sfpn%d_small_balance" % name_args

    train(model, run_name, train_ds, val_ds, test_ds, loader_cfg)
def train_srdn_balance(roiSize=32):
    """Train the compact RDN variant on small patches with a balanced train split.

    The training split is built with ``balance=True``; validation and test
    splits are not. The network is ``RDN(g0=16, d=3, c=4, k=16)``, and the
    run name's ``d3c4k16`` suffix records that configuration.

    Args:
        roiSize: Patch size for every dataset split and for the RDN model;
            also part of the model name.
    """
    train_ds, val_ds, test_ds = (
        sdataset.SmallPatchDataset("train", roiSize, balance=True),
        sdataset.SmallPatchDataset("val", roiSize),
        sdataset.SmallPatchDataset("test", roiSize),
    )

    # Local import: the srdn module is only needed when this experiment runs.
    from srdn import RDN

    # Compact configuration. A larger variant (g0=32, d=4, c=6, k=16) was
    # previously tried here; keep the run-name suffix in sync if this changes.
    rdn_kwargs = dict(g0=16, d=3, c=4, k=16, roiSize=roiSize)
    model = nn.DataParallel(RDN(**rdn_kwargs).cuda())

    loader_cfg = dict(bsize=4096, nworker=20, collate=default_collate)
    run_name = "sp%d_rdn_small_balance_d3c4k16" % roiSize

    train(model, run_name, train_ds, val_ds, test_ds, loader_cfg)
def train_sfpn_small_aug(depth=50, roiSize=224):
    """Train an SFPN model on small patches with geometric augmentation.

    Only the training split receives the augmentation pipeline:
    PIL conversion -> random horizontal flip -> random vertical flip ->
    random rotation (up to 30 degrees) -> resize to ``(roiSize, roiSize)``
    -> tensor. Validation and test splits are built without a transform.

    Args:
        depth: Used only to label the run in the model name.
        roiSize: Patch size for every dataset split and the resize target
            of the augmentation; also part of the model name.
    """
    T = torchvision.transforms
    # NOTE(review): presumes SmallPatchDataset items are images that
    # ToPILImage accepts (tensor or ndarray) -- confirm against the dataset.
    augment = T.Compose([
        T.ToPILImage(),
        T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
        T.RandomRotation(30),
        # Resize after rotation so the network always sees roiSize x roiSize.
        T.Resize((roiSize, roiSize)),
        T.ToTensor(),
    ])

    train_ds = sdataset.SmallPatchDataset("train", roiSize=roiSize, transform=augment)
    val_ds = sdataset.SmallPatchDataset("val", roiSize=roiSize)
    test_ds = sdataset.SmallPatchDataset("test", roiSize=roiSize)

    # Local import: the sfpn module is only needed when this experiment runs.
    from sfpn import SFPN

    # NOTE(review): SFPN() is constructed with its defaults, so `depth`
    # changes only the run name, not the network -- confirm this is intended.
    model = nn.DataParallel(SFPN().cuda())

    # Smaller batch than the non-augmented SFPN run (which uses 192).
    loader_cfg = dict(bsize=128, nworker=20, collate=default_collate)
    name_args = (roiSize, depth)
    run_name = "sp%d_sfpn%d_small_aug" % name_args

    train(model, run_name, train_ds, val_ds, test_ds, loader_cfg)