Ejemplo n.º 1
0
def test():
    """Evaluate a checkpointed Noise2Noise model on the data in ``opt_test``.

    Reads the module-level ``opt_test`` options, builds the model with
    ``trainable=False``, restores weights from ``opt_test.load_ckpt`` and
    runs the model's test loop on an unshuffled, single-mode loader.
    """
    model = Noise2Noise(opt_test, trainable=False)

    # Set after the model is constructed, keeping the original ordering.
    opt_test.redux = False

    # The positional ``3`` is forwarded as-is to load_dataset; its meaning is
    # defined there (presumably a dataset size/limit -- TODO confirm).
    loader = load_dataset(opt_test.data, 3, opt_test,
                          shuffled=False, single=True)

    model.load_model(opt_test.load_ckpt)  # restore the pre-trained weights
    model.test(loader)
Ejemplo n.º 2
0
def train():
    """Train a Noise2Noise model with the settings in ``opt_train``.

    Builds a shuffled training loader and an unshuffled validation loader,
    then hands both to a trainable Noise2Noise instance.
    """
    # Built in order: training split first (shuffled), validation second.
    loaders = [
        load_dataset(opt_train.train_dir, opt_train.train_size, opt_train,
                     shuffled=True),
        load_dataset(opt_train.valid_dir, opt_train.valid_size, opt_train,
                     shuffled=False),
    ]

    model = Noise2Noise(opt_train, trainable=True)
    model.train(*loaders)
Ejemplo n.º 3
0
def main():
    """Test Noise2Noise using command-line parameters.

    Restricts CUDA to device ``0``, parses the CLI arguments, builds a
    non-trainable model from ``pretrain_model_path`` and evaluates it on an
    unshuffled, single-mode loader over ``args.data``.
    """
    # Only GPU 0 is made visible to this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    args = parse_args()

    # Weights are supplied at construction via pretrain_model_path, so no
    # separate load_model() call is made here.
    model = Noise2Noise(args, trainable=False,
                        pretrain_model_path=args.pretrain_model_path)

    # Overrides applied after construction, keeping the original ordering.
    args.redux = False
    args.clean_targets = True

    loader = load_dataset(args.data, 0, args, shuffled=False, single=True)
    model.test(loader, show=args.show_output)
Ejemplo n.º 4
0
    parser.add_argument('-s', '--seed', help='fix random seed', type=int)
    parser.add_argument('-c',
                        '--crop-size',
                        help='image crop size',
                        default=256,
                        type=int)
    parser.add_argument('-t', '--test2', help='test2', action='store_true')
    return parser.parse_args()


if __name__ == '__main__':
    # Test Noise2Noise: parse the CLI parameters, restore the checkpoint, then
    # run test2() when --test2 is given, or the regular dataset test otherwise.
    params = parse_args()

    n2n = Noise2Noise(params, trainable=False)
    params.redux = False
    params.clean_targets = True
    n2n.load_model(params.load_ckpt)

    if params.test2:
        n2n.test2()
    else:
        test_loader = load_dataset(params.data, 0, params,
                                   shuffled=False, single=True)
        n2n.test(test_loader, show=params.show_output)
Ejemplo n.º 5
0
                        '--crop-size',
                        help='random crop size',
                        default=128,
                        type=int)
    parser.add_argument('--clean-targets',
                        help='use clean targets for training',
                        action='store_true')

    return parser.parse_args()


if __name__ == '__main__':
    # Train Noise2Noise from command-line parameters.
    params = parse_args()

    # Only the training split is shuffled; validation keeps a fixed order.
    train_loader = load_dataset(params.train_dir, params.train_size, params,
                                shuffled=True)
    valid_loader = load_dataset(params.valid_dir, params.valid_size, params,
                                shuffled=False)

    n2n = Noise2Noise(params, trainable=True)
    n2n.train(train_loader, valid_loader)
Ejemplo n.º 6
0
# Select the compute device.
# Use torch.device here, not torch.cuda.device. torch.cuda.device is a context
# manager for switching the current CUDA device, not a device handle. It
# rejects non-CUDA devices such as "cpu", and it cannot be passed to .to().
# The second GPU ("cuda:1") is still preferred. Single-GPU machines fall back
# to "cuda:0", where "cuda:1" does not exist, and CUDA-less machines use the CPU.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# TODO
# validation batch_size,
# true_image dataset extraction loop
# image_summary [index] : noise2noise.py : summary.py

# Run configuration consumed by datasetLoader and Noise2Noise.
config = {
    'device': device,
    'dataset': 'gaussian',
    'batch_size_train': 4,
    'batch_size_valid': 1,
    'num_workers': 4,
    'learning_rate': 0.001,
    'checkpoint_dir': './models_trained/.',
    'checkpoint_filename': 'latest.pt',
    'root_dir': '/home/turing/Documents/BE/',
    'log_dir': './log/',
    'data_folder': './data/',
    'num_epochs': 300,
}

trainloader = datasetLoader(config, 'train')
validloader = datasetLoader(config, 'valid')

NET = Noise2Noise(config, trainloader, validloader, trainable=True)

# Presumably restores weights from checkpoint_dir/checkpoint_filename before
# training resumes -- TODO confirm against Noise2Noise.load_model.
NET.load_model(config)

NET.train(config)