Example no. 1
def test_content_loss_raises_if_wrong_reduction(x, y) -> None:
    for mode in ['mean', 'sum', 'none']:
        ContentLoss(reduction=mode)(x, y)

    for mode in [None, 'n', 2]:
        with pytest.raises(ValueError):
            ContentLoss(reduction=mode)(x, y)
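The `x` and `y` arguments above come from pytest fixtures that are not shown in these examples. A minimal sketch of such a conftest, where the batch shape is an assumption for illustration only:

import pytest
import torch


@pytest.fixture()
def x() -> torch.Tensor:
    # Random RGB batch in [0, 1]; the shape is assumed, not taken from the original suite.
    return torch.rand(4, 3, 96, 96)


@pytest.fixture()
def y() -> torch.Tensor:
    return torch.rand(4, 3, 96, 96)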
Example no. 2
def test_content_loss_raises_if_wrong_reduction(prediction: torch.Tensor,
                                                target: torch.Tensor) -> None:
    for mode in ['mean', 'sum', 'none']:
        ContentLoss(reduction=mode)(prediction, target)

    for mode in [None, 'n', 2]:
        with pytest.raises(KeyError):
            ContentLoss(reduction=mode)(prediction, target)
Example no. 3
def test_content_loss_computes_grad(input_tensors: Tuple[torch.Tensor,
                                                         torch.Tensor],
                                    device: str) -> None:
    prediction, target = input_tensors
    prediction.requires_grad_()
    loss_value = ContentLoss()(prediction.to(device), target.to(device))
    loss_value.backward()
    assert prediction.grad is not None, NONE_GRAD_ERR_MSG
Example no. 4
def test_content_loss_computes_grad(input_tensors: Tuple[torch.Tensor,
                                                         torch.Tensor],
                                    device: str) -> None:
    x, y = input_tensors
    x.requires_grad_()
    loss_value = ContentLoss()(x.to(device), y.to(device))
    loss_value.backward()
    assert x.grad is not None, NONE_GRAD_ERR_MSG
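The grad and forward tests additionally rely on `input_tensors` and `device` fixtures that are likewise not shown. A hedged sketch, assuming a CPU/GPU parametrization and an arbitrary batch shape:

import pytest
import torch
from typing import Tuple


@pytest.fixture(params=['cpu'] + (['cuda'] if torch.cuda.is_available() else []))
def device(request) -> str:
    # Run each test on CPU and, when available, also on GPU.
    return request.param


@pytest.fixture()
def input_tensors() -> Tuple[torch.Tensor, torch.Tensor]:
    # Prediction / target pair with an assumed shape.
    return torch.rand(4, 3, 96, 96), torch.rand(4, 3, 96, 96)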
Example no. 5
def test_content_loss_forward_for_special_cases(x, y, expectation: Any,
                                                value: float) -> None:
    loss = ContentLoss()
    with expectation:
        if value is None:
            loss(x, y)
        else:
            loss_value = loss(x, y)
            assert torch.isclose(loss_value, torch.tensor(value)), \
                f'Expected loss value to be equal to target value. Got {loss_value} and {value}'
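Here `expectation` and `value` are injected through parametrization that is not part of the example: `expectation` is a context manager (success path or pytest.raises) and `value` is the expected loss or None. A hypothetical set of cases, where both the inputs and the exception type are assumptions:

import pytest
import torch
from contextlib import nullcontext

zeros = torch.zeros(2, 3, 96, 96)
SPECIAL_CASES = [
    # Identical inputs: the forward pass should succeed and return a loss close to zero.
    pytest.param(zeros, zeros, nullcontext(), 0.0, id='identical inputs'),
    # Mismatched shapes: the call is expected to fail (the exception type is a guess).
    pytest.param(torch.rand(2, 3, 96, 96), torch.rand(2, 3, 64, 64),
                 pytest.raises(AssertionError), None, id='mismatched shapes'),
]

Decorating the test above with @pytest.mark.parametrize('x, y, expectation, value', SPECIAL_CASES) would then drive both branches.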
Example no. 6
def test_content_loss_raises_if_layers_weights_mismatch(x, y) -> None:
    wrong_combinations = ({
        'layers': ['layer1'],
        'weights': [0.5, 0.5]
    }, {
        'layers': ['layer1', 'layer2'],
        'weights': [0.5]
    }, {
        'layers': ['layer1'],
        'weights': []
    })
    for combination in wrong_combinations:
        with pytest.raises(AssertionError):
            ContentLoss(**combination)
Example no. 7
def test_content_loss_doesnt_raise_if_layers_weights_mismatch_but_allowed(
        x, y) -> None:
    wrong_combinations = ({
        'layers': ['relu1_2'],
        'weights': [0.5, 0.5],
        'allow_layers_weights_mismatch': True
    }, {
        'layers': ['relu1_2', 'relu2_2'],
        'weights': [0.5],
        'allow_layers_weights_mismatch': True
    }, {
        'layers': ['relu2_2'],
        'weights': [],
        'allow_layers_weights_mismatch': True
    })
    for combination in wrong_combinations:
        ContentLoss(**combination)
Example no. 8
def main():
    """Parameters initialization and starting SG model training """
    # read command line arguments
    args = get_parser().parse_args()

    # set random seed
    seed_everything(args.seed)

    # paths to dataset
    train_path = osp.join(args.dataset_path, 'train')
    test_path = osp.join(args.dataset_path, 'test')

    # declare generator and discriminator models
    generator = Generator_with_Refin(args.encoder)
    discriminator = Discriminator(input_shape=(3, args.img_size, args.img_size))

    # load weights
    if args.gen_weights != '':
        generator.load_state_dict(torch.load(args.gen_weights))
        print('Generator weights loaded!')

    if args.discr_weights != '':
        discriminator.load_state_dict(torch.load(args.discr_weights))
        print('Discriminator weights loaded!')

    # declare datasets
    train_dataset = ARDataset(train_path,
                              augmentation=get_training_augmentation(args.img_size),
                              augmentation_images=get_image_augmentation(),
                              preprocessing=get_preprocessing(),)

    valid_dataset = ARDataset(test_path,
                              augmentation=get_validation_augmentation(args.img_size),
                              preprocessing=get_preprocessing(),)

    # declare loaders
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # declare loss functions, optimizers and scheduler
    l2loss = nn.MSELoss()
    perloss = ContentLoss(feature_extractor="vgg16", layers=("relu3_3", ))
    GANloss = nn.MSELoss()

    optimizer_G = torch.optim.Adam([dict(params=generator.parameters(), lr=args.lr_G),])
    optimizer_D = torch.optim.Adam([dict(params=discriminator.parameters(), lr=args.lr_D),])

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_G, mode='min', factor=0.9, patience=args.patience)

    # device
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    # tensorboard
    writer = SummaryWriter()

    # start training
    train(
        generator=generator,
        discriminator=discriminator,
        device=device,
        n_epoch=args.n_epoch,
        optimizer_G=optimizer_G,
        optimizer_D=optimizer_D,
        train_loader=train_loader,
        valid_loader=valid_loader,
        scheduler=scheduler,
        losses=[l2loss, perloss, GANloss],
        models_paths=[args.Gmodel_path, args.Dmodel_path],
        bettas=[args.betta1, args.betta2, args.betta3],
        writer=writer,
    )
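The script depends on a get_parser() helper that is not shown. A hypothetical sketch that is consistent with the attributes main() reads; the argument names match the code above, while defaults are assumptions:

import argparse


def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description='SG model training')
    # Data and reproducibility
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--dataset_path', type=str, required=True)
    parser.add_argument('--img_size', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--num_workers', type=int, default=4)
    # Models and optional pretrained weights
    parser.add_argument('--encoder', type=str, default='resnet34')
    parser.add_argument('--gen_weights', type=str, default='')
    parser.add_argument('--discr_weights', type=str, default='')
    parser.add_argument('--Gmodel_path', type=str, default='generator.pth')
    parser.add_argument('--Dmodel_path', type=str, default='discriminator.pth')
    # Optimization
    parser.add_argument('--lr_G', type=float, default=1e-4)
    parser.add_argument('--lr_D', type=float, default=1e-4)
    parser.add_argument('--patience', type=int, default=5)
    parser.add_argument('--n_epoch', type=int, default=100)
    parser.add_argument('--betta1', type=float, default=1.0)
    parser.add_argument('--betta2', type=float, default=1.0)
    parser.add_argument('--betta3', type=float, default=1.0)
    # Runtime
    parser.add_argument('--device', type=str, default='cuda')
    return parser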
Example no. 9
def test_content_loss_supports_custom_extractor(x, y, device: str) -> None:
    loss = ContentLoss(feature_extractor=InceptionV3().blocks,
                       layers=['0', '1'],
                       weights=[0.5, 0.5])
    loss(x, y)
Example no. 10
def test_content_loss_replace_pooling(x, y, model: Union[str,
                                                         Callable]) -> None:
    ContentLoss(feature_extractor=model, replace_pooling=True)
Example no. 11
def test_content_loss_raises_if_wrong_extractor(x, y, model: Union[str,
                                                                   Callable],
                                                expectation: Any) -> None:
    with expectation:
        ContentLoss(feature_extractor=model)
Example no. 12
def test_content_loss_forward(input_tensors: Tuple[torch.Tensor, torch.Tensor],
                              device: str) -> None:
    prediction, target = input_tensors
    loss = ContentLoss()
    loss(prediction.to(device), target.to(device))
Example no. 13
def test_content_loss_raises_if_wrong_extractor(prediction: torch.Tensor,
                                                target: torch.Tensor,
                                                model: Union[str, Callable],
                                                expectation: Any) -> None:
    with expectation:
        ContentLoss(feature_extractor=model)
Example no. 14
def test_content_loss_forward(input_tensors: Tuple[torch.Tensor, torch.Tensor],
                              device: str) -> None:
    x, y = input_tensors
    loss = ContentLoss()
    loss(x.to(device), y.to(device))
Example no. 15
def test_content_loss_init() -> None:
    ContentLoss()
Example no. 16
def test_content_loss_forward_for_normalized_input(device: str) -> None:
    prediction = torch.randn(2, 3, 96, 96).to(device)
    target = torch.randn(2, 3, 96, 96).to(device)
    loss = ContentLoss(mean=[0., 0., 0.], std=[1., 1., 1.])
    loss(prediction.to(device), target.to(device))
Example no. 17
def test_content_loss_forward_for_normalized_input(device: str) -> None:
    x = torch.randn(2, 3, 96, 96).to(device)
    y = torch.randn(2, 3, 96, 96).to(device)
    loss = ContentLoss(mean=[0., 0., 0.], std=[1., 1., 1.])
    loss(x.to(device), y.to(device))
Example no. 18
def test_content_loss_supports_custom_extractor(prediction: torch.Tensor,
                                                target: torch.Tensor,
                                                device: str) -> None:
    loss = ContentLoss(feature_extractor=InceptionV3().blocks,
                       layers=['0', '1'])
    loss(prediction, target)
Example no. 19
def test_content_loss_replace_pooling(prediction: torch.Tensor,
                                      target: torch.Tensor,
                                      model: Union[str, Callable]) -> None:
    ContentLoss(feature_extractor=model, replace_pooling=True)
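Finally, the `model` and `expectation` arguments of the extractor tests are also supplied by parametrization that is not shown. A minimal sketch, with the extractor names and the exception type assumed for illustration:

import pytest
from contextlib import nullcontext


@pytest.fixture(params=['vgg16', 'vgg19'])
def model(request) -> str:
    # Extractor names assumed to be supported; the exact set is not taken from the original suite.
    return request.param


WRONG_EXTRACTOR_CASES = [
    # A known name should construct fine; an unknown one should raise (exception type is a guess).
    pytest.param('vgg16', nullcontext(), id='known extractor'),
    pytest.param('unknown_net', pytest.raises(ValueError), id='unknown extractor'),
]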