# Select the hyper-parameter set chosen on the command line.
cfg = configurations[args.config]
# `torch.cuda.is_available` must be *called*: the bare function object is
# always truthy, so a later `if cuda:` check would take the GPU path even on
# machines without CUDA.
cuda = torch.cuda.is_available()

# ~~~ dataset loader ~~~
train_dataRoot = args.train_dataroot
test_dataRoot = args.test_dataroot

# Make sure both output roots exist. makedirs(exist_ok=True) also creates any
# missing parent directories and avoids the check-then-create race of
# `if not os.path.exists(p): os.mkdir(p)`.
os.makedirs(args.snapshot_root, exist_ok=True)
os.makedirs(args.salmap_root, exist_ok=True)

if args.phase == 'train':
    SnapRoot = args.snapshot_root  # checkpoint directory
    train_loader = torch.utils.data.DataLoader(
        MyData(train_dataRoot, transform=True),
        batch_size=2,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )
if args.phase == 'test':
    # Presumably the root where predicted saliency maps are written
    # (per the `salmap_root` argument name) — confirm against later usage.
    MapRoot = args.salmap_root
    # NOTE(review): shuffling is not needed for evaluation; kept as in the
    # original. Consider shuffle=False for a deterministic processing order.
    test_loader = torch.utils.data.DataLoader(
        MyTestData(test_dataRoot, transform=True),
        batch_size=1,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )
print('data already')

# ~~~ nets ~~~
""" Title: Depth-induced Multi-scale Recurrent Attention Network for Saliency Detection Author: Wei Ji, Jingjing Li E-mail: [email protected] """ import torch from torch.autograd import Variable from torch.utils.data import DataLoader import torchvision import torch.nn.functional as F import torch.optim as optim from dataset_loader import MyData, MyTestData from model import RGBNet,DepthNet from fusion import ConvLSTM from functions import imsave import argparse from trainer import Trainer import os configurations = { # same configuration as original work # https://github.com/shelhamer/fcn.berkeleyvision.org 1: dict( max_iteration=1000000, lr=1.0e-10, momentum=0.99, weight_decay=0.0005, spshot=20000, nclass=2, sshow=10,
import torch from torch.autograd import Variable from torch.utils.data import DataLoader import torchvision import torch.nn.functional as F import torch.optim as optim from dataset_loader import MyData, MyTestData, DTestData from model import FocalNet, FocalNet_sub from conv_lstm import ConvLSTM from functions import imsave import argparse from Trainer_Teacher import Trainer import os if __name__ == '__main__': configurations = { 1: dict( max_iteration=500000, lr=1.0e-10, momentum=0.99, weight_decay=0.0005, spshot=10000, nclass=2, sshow=10, focal_num=12, ) } parser=argparse.ArgumentParser() parser.add_argument('--phase', type=str, default='test', help='train or test') parser.add_argument('--param', type=str, default=True, help='path to pre-trained parameters')
import torch from torch.autograd import Variable from torch.utils.data import DataLoader import torchvision import torch.nn.functional as F import torch.optim as optim from dataset_loader import MyData, MyTestData from model import FocalNet, FocalNet_sub from conv_lstm import ConvLSTM from functions import imsave import argparse from Trainer_Student import Trainer from resnet_18 import Resnet_18 import os import imageio if __name__ == '__main__': configurations = { 1: dict( max_iteration=300000, lr=1.0e-10, momentum=0.99, weight_decay=0.0005, spshot=10000, nclass=2, sshow=10, focal_num=12, ) } parser=argparse.ArgumentParser()