Пример #1
0
# Script setup: pick the run configuration, detect CUDA, make sure the output
# directories exist, and build the DataLoader for the selected phase.
cfg = configurations[args.config]

# ``torch.cuda.is_available`` must be *called*: the bare attribute is a
# function object, which is always truthy and would report a GPU even on
# CPU-only hosts.
cuda = torch.cuda.is_available()

# ~~~ dataset loader ~~~

train_dataRoot = args.train_dataroot
test_dataRoot = args.test_dataroot

# Ensure the checkpoint root and the saliency-map root exist.
# ``os.makedirs(..., exist_ok=True)`` also creates missing parent directories
# and avoids the race between an ``os.path.exists`` check and ``os.mkdir``.
os.makedirs(args.snapshot_root, exist_ok=True)
os.makedirs(args.salmap_root, exist_ok=True)

if args.phase == 'train':
    SnapRoot = args.snapshot_root  # checkpoint directory
    # Training loader: batches of 2 samples, reshuffled every epoch.
    train_loader = torch.utils.data.DataLoader(
        MyData(train_dataRoot, transform=True),
        batch_size=2,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )

if args.phase == 'test':
    MapRoot = args.salmap_root  # root taken from args.salmap_root
    # NOTE(review): shuffle=True on the test loader is unusual; it only
    # matters if downstream code depends on sample order — confirm.
    test_loader = torch.utils.data.DataLoader(
        MyTestData(test_dataRoot, transform=True),
        batch_size=1,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )

print('data already')

# ~~~ nets ~~~
Пример #2
0
"""
Title: Depth-induced Multi-scale Recurrent Attention Network for Saliency Detection
Author: Wei Ji, Jingjing Li
E-mail: [email protected]
"""
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from dataset_loader import MyData, MyTestData
from model import RGBNet,DepthNet
from fusion import ConvLSTM
from functions import imsave
import argparse
from trainer import Trainer
import os

configurations = {
    # same configuration as original work
    # https://github.com/shelhamer/fcn.berkeleyvision.org
    1: dict(
        max_iteration=1000000,
        lr=1.0e-10,
        momentum=0.99,
        weight_decay=0.0005,
        spshot=20000,
        nclass=2,
        sshow=10,
Пример #3
0
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from dataset_loader import MyData, MyTestData, DTestData
from model import FocalNet, FocalNet_sub
from conv_lstm import ConvLSTM
from functions import imsave
import argparse
from Trainer_Teacher import Trainer
import os

if __name__ == '__main__':
    # Hyper-parameter sets, keyed by configuration id.
    configurations = {
        1: {
            'max_iteration': 500000,
            'lr': 1.0e-10,
            'momentum': 0.99,
            'weight_decay': 0.0005,
            'spshot': 10000,
            'nclass': 2,
            'sshow': 10,
            'focal_num': 12,
        },
    }

    # Command-line interface.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--phase',
        type=str,
        default='test',
        help='train or test',
    )
    # NOTE(review): the default is the bool ``True``, not a str; argparse only
    # applies ``type=str`` to values given on the command line (and to string
    # defaults), so ``args.param`` is ``True`` unless overridden. Presumably
    # later code treats ``True`` as a sentinel — confirm before changing.
    parser.add_argument(
        '--param',
        type=str,
        default=True,
        help='path to pre-trained parameters',
    )
Пример #4
0
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from dataset_loader import MyData, MyTestData
from model import FocalNet, FocalNet_sub
from conv_lstm import ConvLSTM
from functions import imsave
import argparse
from Trainer_Student import Trainer
from resnet_18 import Resnet_18
import os
import imageio

if __name__ == '__main__':
    configurations = {
        1: dict(
            max_iteration=300000,
            lr=1.0e-10,
            momentum=0.99,
            weight_decay=0.0005,
            spshot=10000,
            nclass=2,
            sshow=10,
            focal_num=12,
        )
    }
    parser=argparse.ArgumentParser()