Code example #1
0
def test_exclude_discriminators(basic_gan):
	"""Check that ``opt.num_discrims`` controls how many ``netD1s`` are built.

	``basic_gan`` is requested only for its side effects (presumably the
	fixture's patching of the model's network constructors — confirm); its
	return value is not used here.
	"""
	opt = standard_options()

	# No discriminators requested -> none constructed.
	opt.num_discrims = 0
	geo = Pix2PixGeoModel()
	geo.initialize(opt)
	assert len(geo.netD1s) == 0

	# Two discriminators requested on a fresh model -> two constructed.
	opt.num_discrims = 2
	geo = Pix2PixGeoModel()
	geo.initialize(opt)
	assert len(geo.netD1s) == 2
Code example #2
0
def test_correct_shape_for_folder_id(mocker):
	"""Check that the folder-prediction head emits one logit per folder.

	Uses a real ``UnetGenerator`` (optimisers, scheduler and discriminators are
	mocked out), runs one forward pass, then feeds the flattened innermost-layer
	activation through ``gan.folder_fc`` and checks the output shape.
	"""
	batch_size = 2
	num_folders = 20

	mocker.patch('torch.optim')
	mocker.patch('models.networks.get_scheduler')
	mocker.patch('models.networks.define_D', side_effect=mocker.MagicMock())
	mocker.patch('models.networks.define_G', return_value=UnetGenerator(3, 3, 7, 4))

	gan = Pix2PixGeoModel()

	opt = standard_options()
	opt.num_discrims = 1
	opt.low_iter = 1
	opt.high_iter = 1

	opt.which_model_netD = 'wgan-gp'
	opt.lambda_A = 1
	opt.lambda_B = 1
	opt.lambda_C = 1
	opt.discrete_only = False
	opt.local_loss = False
	opt.fineSize = 256
	opt.ngf = 4
	opt.num_folders = num_folders
	opt.folder_pred = True

	gan.initialize(opt)

	x = torch.rand((batch_size, 3, 256, 512))

	# The generator's output is not needed; the call is made so that
	# ``gan.netG.inner_layer.output`` is populated (presumably by a forward
	# hook registered during initialize — confirm).
	gan.netG(x)

	innermost = gan.netG.inner_layer.output.view(batch_size, -1)
	folder = gan.folder_fc(innermost)

	assert folder.shape == (batch_size, num_folders)
Code example #3
0
def test_generator_w_continuous(basic_gan):
	"""Check the generator and continuous heads receive the expected inputs.

	With ``discrete_only`` disabled, ``forward`` should pass ``[A, mask_float]``
	to ``netG`` and the discrete output to each of the DIV / Vx / Vy heads.
	``basic_gan`` is requested for its side effects on model construction.
	"""
	from unittest import mock

	gan = Pix2PixGeoModel()

	opt = standard_options()
	opt.discrete_only = False

	gan.initialize(opt)

	fake_dataset = DatasetMock()

	# Use three independent mocks. ``basic_gan`` is itself a MagicMock, so
	# ``basic_gan.mocker.MagicMock()`` would return the same auto-created child
	# on every call. All three heads would then share one mock, and a head that
	# was never called would go unnoticed.
	gan.netG_DIV = mock.MagicMock()
	gan.netG_Vx = mock.MagicMock()
	gan.netG_Vy = mock.MagicMock()

	gan.set_input(fake_dataset)
	gan.forward()

	# call_args is (positional_args_tuple, kwargs_dict), so [0][0] is the
	# first positional argument.
	assert gan.netG.call_args[0][0].name == ['A', 'mask_float']
	for head in (gan.netG_DIV, gan.netG_Vx, gan.netG_Vy):
		assert head.call_args[0][0].name == 'fake_discrete_output'
Code example #4
0
def test_generator(basic_gan):
	"""Check that ``forward`` hands ``[A, mask_float]`` to the generator.

	``basic_gan`` is the mocked generator; its most recent call is inspected.
	"""
	generator_mock = basic_gan

	model = Pix2PixGeoModel()
	model.initialize(standard_options())

	model.set_input(DatasetMock())
	model.forward()

	positional_args, _kwargs = generator_mock.call_args
	assert positional_args[0].name == ['A', 'mask_float']
Code example #5
0
def test_with_continents(basic_gan, mocker):
	"""Check that ``cont_float`` reaches the generator and both discriminator stacks.

	With ``continent_data`` enabled, run one optimisation step. Then inspect the
	first positional argument each network received, identified by the
	``.name`` of the mock tensors.
	"""
	generator_mock = basic_gan

	model = Pix2PixGeoModel()

	opt = standard_options()
	opt.num_discrims = 1
	opt.low_iter = 1
	opt.high_iter = 1

	opt.lambda_A = 1
	opt.lambda_B = 1
	opt.lambda_C = 1
	opt.discrete_only = False
	opt.local_loss = False
	opt.continent_data = True

	model.initialize(opt)

	# Each discrete discriminator must differ from its continuous partner.
	assert all(d1 != d2 for d1, d2 in zip(model.netD1s, model.netD2s))

	model.netG_DIV = fake_network(mocker, 'fake_DIV')
	model.netG_Vx = fake_network(mocker, 'fake_Vx')
	model.netG_Vy = fake_network(mocker, 'fake_Vy')

	model.set_input(DatasetMock())

	model.optimize_parameters(step_no=1)

	def first_arg_name(net, call_index):
		# Name of the first positional argument of the given recorded call.
		return net.call_args_list[call_index][0][0].name

	assert generator_mock.call_args[0][0].name == ['A', 'mask_float', 'cont_float']
	for head in (model.netG_DIV, model.netG_Vx, model.netG_Vy):
		assert head.call_args[0][0].name == 'fake_discrete_output'

	# Call 0 receives the detached generated input and call 1 the real-target
	# input. Any later call (possibly a gradient penalty, unverified) is not
	# checked.
	for d1 in model.netD1s:
		assert first_arg_name(d1, 0) == '[A, mask_float, cont_float, fake_discrete_output]_detach'
		assert first_arg_name(d1, 1) == ['A', 'mask_float', 'cont_float', 'B']

	for d2 in model.netD2s:
		assert first_arg_name(d2, 0) == '[A, mask_float, cont_float, fake_DIV, fake_Vx, fake_Vy]_detach'
		assert first_arg_name(d2, 1) == ['A', 'mask_float', 'cont_float', 'B_DIV', 'B_Vx', 'B_Vy']
Code example #6
0
def test_div_only(basic_gan, mocker):
	"""Check that with ``div_only`` the continuous discriminator sees only DIV.

	``basic_gan`` is requested for its side effects on model construction.
	"""
	opt = standard_options()
	opt.num_discrims = 1
	opt.div_only = True

	model = Pix2PixGeoModel()
	model.initialize(opt)

	define_d = models.networks.define_D
	assert define_d.called
	# NOTE(review): the first positional argument of the last define_D call is
	# presumably the discriminator's input channel count. Confirm.
	assert define_d.call_args[0][0] == 5

	model.netG_DIV = fake_network(mocker, 'fake_DIV')

	model.set_input(DatasetMock())
	model.optimize_parameters(step_no=1)

	assert model.netG.call_args[0][0].name == ['A', 'mask_float']

	d2_calls = model.netD2s[0].call_args_list
	assert d2_calls[0][0][0].name == '[A, mask_float, fake_DIV]_detach'
	assert d2_calls[1][0][0].name == ['A', 'mask_float', 'B_DIV']
Code example #7
0
def test_folder_id_used_in_cross_entropy_loss(basic_gan, mocker):
	"""Check the optional folder-prediction head and its cross-entropy loss.

	1. Before ``folder_pred`` is enabled, no ``torch.nn.Linear`` head is built.
	   This relies on ``standard_options()`` leaving ``folder_pred`` falsy.
	2. With ``folder_pred = True``, a Linear head of the expected size is built
	   and a forward hook is registered on the generator's innermost block.
	3. After one optimisation step, the predictions named
	   ``fake_folder_softmax`` and the ``folder_id`` target are passed to the
	   cross-entropy function, and the discriminators receive their usual inputs.
	"""
	MockGenerator = basic_gan
	MockInnerLayer = fake_network(mocker, 'innermost')

	# Make the model's innermost-block lookup return a mock we can inspect.
	mocker.patch('models.pix2pix_geo_model.get_innermost',
		new=mocker.MagicMock(return_value=MockInnerLayer))

	gan = Pix2PixGeoModel()

	opt = standard_options()
	opt.num_discrims = 1
	opt.low_iter = 1
	opt.high_iter = 1

	opt.which_model_netD = 'cwgan-gp'
	opt.lambda_A = 1
	opt.lambda_B = 1
	opt.lambda_C = 1
	opt.discrete_only = False
	opt.local_loss = False
	opt.fineSize = 256
	opt.num_folders = 20

	# Folder prediction not enabled yet: no Linear head should be created.
	mocker.patch('torch.nn.Linear')
	mocker.patch('models.pix2pix_geo_model.get_downsample', return_value=32)
	gan.initialize(opt)

	torch.nn.Linear.assert_not_called()

	# Folder prediction enabled: the Linear head must be created with the
	# expected fan-in and one output per folder.
	opt.folder_pred = True

	mocker.patch('torch.nn.Linear')
	mocker.patch('models.pix2pix_geo_model.get_downsample', return_value=32)
	gan.initialize(opt)

	# NOTE(review): the fan-in is 2 * fineSize**2 / downsample**2 * 420 * 8.
	# Where the 420 * 8 factor comes from is not visible here; confirm against
	# the model.
	torch.nn.Linear.assert_called_with(2*256**2 / 32**2 * 420*8, 20)

	gan.netG = MockGenerator

	gan.netG_DIV = fake_network(mocker, 'fake_DIV')
	gan.netG_Vx = fake_network(mocker, 'fake_Vx')
	gan.netG_Vy = fake_network(mocker, 'fake_Vy')

	fake_dataset = DatasetMock()

	gan.set_input(fake_dataset)
	models.pix2pix_geo_model.get_innermost.assert_called_with(gan.netG, 'UnetSkipConnectionBlock')

	# set_input must attach the innermost block and hook it so its output is saved.
	assert gan.netG.inner_layer == MockInnerLayer
	assert MockInnerLayer.register_forward_hook.called
	MockInnerLayer.register_forward_hook.assert_called_with(models.pix2pix_geo_model.save_output_hook)

	gan.folder_fc = fake_network(mocker, 'fake_folder')

	# criterionCE(...) returns ce_fun_mock, whose calls are inspected below.
	ce_fun_mock = mocker.MagicMock()
	gan.criterionCE = mocker.MagicMock(return_value=ce_fun_mock)

	gan.optimize_parameters(step_no=1)

	assert gan.fake_folder.name == 'fake_folder_softmax'
	assert gan.real_folder.name == 'folder_id'

	assert gan.netD1s[0].call_args_list[0][0][0].name == '[A, mask_float, fake_discrete_output]_detach'
	assert gan.netD1s[0].call_args_list[1][0][0].name == ['A', 'mask_float', 'B']
	assert len(gan.netD1s[0].call_args_list) == 3

	assert gan.netD2s[0].call_args_list[0][0][0].name == '[A, mask_float, fake_DIV, fake_Vx, fake_Vy]_detach'
	assert gan.netD2s[0].call_args_list[1][0][0].name == ['A', 'mask_float', 'B_DIV', 'B_Vx', 'B_Vy']
	assert len(gan.netD2s[0].call_args_list) == 3

	# ce_fun_mock is called twice. The second call must receive
	# (prediction, target); the first call is not inspected here.
	assert len(ce_fun_mock.call_args_list) == 2
	assert ce_fun_mock.call_args_list[1][0][0].name == 'fake_folder_softmax'
	assert ce_fun_mock.call_args_list[1][0][1].name == 'folder_id'
Code example #8
0
def test_generator_discriminator(basic_gan, mocker):
	"""Run one optimisation step and check what G and both D stacks receive.

	Also checks that ``define_D`` was called, that its last call got ``7`` as
	the first positional argument, and that paired discrete/continuous
	discriminators are distinct objects.
	"""
	MockGenerator = basic_gan

	gan = Pix2PixGeoModel()

	opt = standard_options()
	opt.num_discrims = 1
	opt.low_iter = 1
	opt.high_iter = 1

	opt.lambda_A = 1
	opt.lambda_B = 1
	opt.lambda_C = 1
	opt.discrete_only = False
	opt.local_loss = False

	gan.initialize(opt)

	assert models.networks.define_D.called
	# NOTE(review): this is presumably the input channel count of the last-built
	# discriminator. Confirm against models.networks.define_D.
	assert models.networks.define_D.call_args[0][0] == 7

	assert all([netD1 != netD2 for netD1, netD2 in zip(gan.netD1s, gan.netD2s)])

	gan.netG_DIV = fake_network(mocker, 'fake_DIV')
	gan.netG_Vx = fake_network(mocker, 'fake_Vx')
	gan.netG_Vy = fake_network(mocker, 'fake_Vy')

	fake_dataset = DatasetMock()

	gan.set_input(fake_dataset)

	gan.optimize_parameters(step_no=1)

	# call_args[0][0] is the first positional argument of the most recent call.
	assert MockGenerator.call_args[0][0].name == ['A', 'mask_float']
	for head in (gan.netG_DIV, gan.netG_Vx, gan.netG_Vy):
		assert head.call_args[0][0].name == 'fake_discrete_output'

	# Call 0 receives the detached generated input and call 1 the real-target
	# input. Any later call (possibly a gradient penalty, unverified) is not
	# checked.
	for netD1 in gan.netD1s:
		assert netD1.call_args_list[0][0][0].name == '[A, mask_float, fake_discrete_output]_detach'
		assert netD1.call_args_list[1][0][0].name == ['A', 'mask_float', 'B']

	for netD2 in gan.netD2s:
		assert netD2.call_args_list[0][0][0].name == '[A, mask_float, fake_DIV, fake_Vx, fake_Vy]_detach'
		assert netD2.call_args_list[1][0][0].name == ['A', 'mask_float', 'B_DIV', 'B_Vx', 'B_Vy']