コード例 #1
0
ファイル: test_core.py プロジェクト: lorin/cider
    def test_read_defaults(self, tmpdir, debug, verbose, data):
        """read_defaults() hands back whatever ``cider.core.read_json`` yields.

        Also checks that the loader was called with the defaults file and an
        empty dict as its second argument.
        """
        with patch("cider.core.read_json") as read_json:
            cider = Cider(False, debug, verbose, cider_dir=str(tmpdir))
            read_json.return_value = data

            loaded = cider.read_defaults()
            assert loaded == data
            read_json.assert_called_with(cider.defaults_file, {})
コード例 #2
0
ファイル: test_core.py プロジェクト: a-b/cider
    def test_mklink(self, tmpdir, debug, verbose):
        """Cover mklink() for missing, valid, blocked and forced targets."""
        cache_dir = str(tmpdir.join(".cache"))
        cider = Cider(
            False, debug, verbose, cider_dir=str(tmpdir), support_dir=cache_dir
        )
        source = str(tmpdir.join(random_str(min_length=1)))
        target = str(tmpdir.join(random_str(min_length=1)))

        # A source that does not exist yet must raise SymlinkError.
        with pytest.raises(SymlinkError):
            assert not cider.mklink(source, target)

        # With the source in place, linking works and can be repeated.
        touch(source)
        for _attempt in range(2):
            linked = cider.mklink(source, target)
            assert linked
            assert os.path.islink(target)

        # A regular file already sitting at the target blocks the link.
        os.remove(target)
        touch(target)
        blocked = cider.mklink(source, target)
        assert not blocked
        assert not os.path.islink(target)

        # With force=True the existing target may be removed first.
        with patch("cider._osx.move_to_trash", side_effect=os.remove):
            forced = cider.mklink(source, target, force=True)
            assert forced
コード例 #3
0
ファイル: test_core.py プロジェクト: a-b/cider
    def test_missing_taps(self, tmpdir, debug, verbose, installed, brewed):
        """missing_taps() should equal the sorted set ``brewed - installed``."""
        cider = Cider(False, debug, verbose, cider_dir=str(tmpdir))
        cider.brew = MagicMock()
        # The mocked ``brew.tap`` returns one tap name per line.
        tap_listing = "\n".join(brewed)
        cider.brew.tap = MagicMock(return_value=tap_listing)

        expected = sorted(set(brewed).difference(installed))
        assert cider.missing_taps() == expected
コード例 #4
0
ファイル: test_core.py プロジェクト: timgates42/cider
    def test_read_defaults(self, tmpdir, debug, verbose, data):
        """read_defaults() hands back whatever ``cider.core.read_config`` yields.

        Also checks that the loader was called with the defaults file and an
        empty dict as its second argument.
        """
        with patch("cider.core.read_config") as read_config:
            cider = Cider(False, debug, verbose, cider_dir=str(tmpdir))
            read_config.return_value = data

            loaded = cider.read_defaults()
            assert loaded == data
            read_config.assert_called_with(cider.defaults_file, {})
コード例 #5
0
ファイル: test_core.py プロジェクト: timgates42/cider
    def test_missing_taps(self, tmpdir, debug, verbose, installed, brewed):
        """Taps reported by brew but not in ``installed`` come back sorted."""
        cider = Cider(False, debug, verbose, cider_dir=str(tmpdir))
        brew = MagicMock()
        # The mocked ``tap`` call returns a newline-separated listing.
        brew.tap = MagicMock(return_value="\n".join(brewed))
        cider.brew = brew

        not_installed = set(brewed).difference(installed)
        assert cider.missing_taps() == sorted(not_installed)
コード例 #6
0
ファイル: test_core.py プロジェクト: lorin/cider
    def test_apply_defaults(self, tmpdir, debug, verbose, defaults):
        """apply_defaults() writes every (domain, key, value) it reads."""
        cider = Cider(False, debug, verbose, cider_dir=str(tmpdir))
        cider.defaults = MagicMock()
        cider.read_defaults = MagicMock(return_value=defaults)
        cider.apply_defaults()

        # Flatten the nested mapping into the argument triples we expect.
        expected_writes = [
            (domain, key, value)
            for domain, options in defaults.items()
            for key, value in options.items()
        ]
        for write_args in expected_writes:
            cider.defaults.write.assert_any_call(*write_args)
コード例 #7
0
ファイル: test_core.py プロジェクト: lorin/cider
    def test_rm(self, tmpdir, cask, debug, verbose, formulas):
        """rm() forwards to brew.rm once and drops the names from bootstrap."""
        cider = Cider(cask, debug, verbose, cider_dir=str(tmpdir))
        cider.brew = MagicMock()

        cider.rm(*formulas)
        cider.brew.rm.assert_called_once_with(*formulas)

        # Casks and formulas are stored under different bootstrap keys.
        if cask:
            key = "casks"
        else:
            key = "formulas"
        for formula in formulas:
            remaining = cider.read_bootstrap().get(key, [])
            assert formula not in remaining
コード例 #8
0
 def CIDEr_call(self):
     """Score predicted captions against references with CIDEr and log it.

     Reads ``self.real_file`` and ``self.predict_file`` line by line in
     lockstep (``zip`` stops at the shorter file), hands both lists to
     ``Cider.compute_score`` and writes the result to ``self.f``.
     Each reference line is wrapped in a one-element list; lines are passed
     on exactly as read, including their trailing newline.
     """
     cider_instance = Cider()
     references = []
     candidates = []
     # Context managers close both files even if reading fails; the
     # handles were previously never closed.
     with open(self.real_file, 'r') as f_real, \
             open(self.predict_file, 'r') as f_predict:
         for real_caption, predict_caption in zip(f_real, f_predict):
             references.append([real_caption])
             candidates.append(predict_caption)
     cider_score = cider_instance.compute_score(candidates, references)
     self.f.write("CIDEr score: " + str(cider_score) + '\n\n')
コード例 #9
0
ファイル: test_core.py プロジェクト: lorin/cider
    def test_read_bootstrap(self, tmpdir, cask, debug, verbose, data):
        """read_bootstrap() returns loaded data and maps ENOENT to an error.

        First the patched ``cider.core.read_json`` returns ``data``; then it
        is made to raise ``IOError(ENOENT)`` and ``BootstrapMissingError`` is
        expected. In both cases the loader must have been called with the
        bootstrap file.
        """
        with patch("cider.core.read_json") as mock:
            cider = Cider(cask, debug, verbose, cider_dir=str(tmpdir))
            mock.return_value = data
            assert cider.read_bootstrap() == data
            mock.assert_called_with(cider.bootstrap_file)

            mock.side_effect = IOError(errno.ENOENT, "")
            with pytest.raises(BootstrapMissingError):
                cider.read_bootstrap()
            # Checked after the block: anything placed after the raising call
            # inside ``pytest.raises`` is never executed.
            mock.assert_called_with(cider.bootstrap_file)
コード例 #10
0
ファイル: test_core.py プロジェクト: timgates42/cider
    def test_installed(self, tmpdir, cask, debug, verbose, random_prefix,
                       bootstrap):
        """installed(prefix) filters the bootstrap list by name prefix.

        Tried with no prefix, with one installed name used as prefix (when
        any exist), and with a random prefix.
        """
        cider = Cider(cask, debug, verbose, cider_dir=str(tmpdir))
        cider.read_bootstrap = MagicMock(return_value=bootstrap)

        installed = bootstrap.get("casks" if cask else "formulas", [])
        if installed:
            sampled = random.choice(installed)
        else:
            sampled = None

        for prefix in (None, sampled, random_prefix):
            actual = cider.installed(prefix)
            # A falsy prefix means "no filtering".
            expected = [
                name for name in installed
                if not prefix or name.startswith(prefix)
            ]
            assert actual == expected
コード例 #11
0
ファイル: test_core.py プロジェクト: lorin/cider
    def test_installed(self, tmpdir, cask, debug, verbose,
                       random_prefix, bootstrap):
        """Check installed() against a prefix filter over the bootstrap list.

        The prefixes tried are ``None``, a randomly chosen installed name (or
        ``None`` when nothing is installed) and ``random_prefix``.
        """
        cider = Cider(cask, debug, verbose, cider_dir=str(tmpdir))
        cider.read_bootstrap = MagicMock(return_value=bootstrap)

        bootstrap_key = "casks" if cask else "formulas"
        names = bootstrap.get(bootstrap_key, [])
        picked = random.choice(names) if names else None

        for prefix in (None, picked, random_prefix):
            result = cider.installed(prefix)
            wanted = []
            for name in names:
                # A falsy prefix matches everything.
                if not prefix or name.startswith(prefix):
                    wanted.append(name)
            assert result == wanted
コード例 #12
0
ファイル: test_core.py プロジェクト: lorin/cider
    def test_run_scripts(self, tmpdir, debug, verbose, before,
                         after, bootstrap):
        """run_scripts() spawns each selected before/after bootstrap script."""
        cider = Cider(False, debug, verbose, cider_dir=str(tmpdir))
        cider.read_bootstrap = MagicMock(return_value=bootstrap)

        # Only the groups enabled by the flags are expected to run.
        scripts = []
        if before:
            scripts.extend(bootstrap.get("before-scripts", []))
        if after:
            scripts.extend(bootstrap.get("after-scripts", []))

        # TODO: Assert ordering
        spawn_patch = patch("cider.core.spawn", autospec=True, return_value=0)
        with spawn_patch as spawn:
            cider.run_scripts(before, after)
            for script in scripts:
                spawn.assert_any_call(
                    [script], shell=True, debug=debug, cwd=cider.cider_dir
                )
コード例 #13
0
 def __init__(self, device, args):
     """Build the generator, losses, optimizer, evaluator and data loader.

     Args:
         device: Device the generator is moved to.
         args: Settings object; ``batch_size``, ``checkpoint_dir`` and
             ``learning_rate`` are read here, and the object is also passed
             to the project classes constructed below.
     """
     self.device = device
     self.args = args
     self.batch_size = args.batch_size
     self.checkpoint_path = os.path.join(args.checkpoint_dir, 'generator.pth')
     # makedirs(exist_ok=True) also creates missing parents and avoids the
     # check-then-create race of isdir() followed by mkdir().
     os.makedirs(args.checkpoint_dir, exist_ok=True)
     self.generator = Generator(args).to(self.device)
     self.sequence_loss = SequenceLoss()
     self.reinforce_loss = ReinforceLoss()
     self.optimizer = optim.Adam(self.generator.parameters(), lr=args.learning_rate)
     self.evaluator = Evaluator('val', self.device, args)
     self.cider = Cider(args)
     dataset = CaptionDataset('train', args)
     self.loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)
コード例 #14
0
def eval_captions(gt_captions, res_captions):
	"""Score generated captions with BLEU-1..4, METEOR, ROUGE-L and CIDEr.

	gt_captions = ground truth captions; 5 per image
	res_captions = captions generated by the model to be evaluated

	Both inputs are printed, then passed unchanged to each scorer's
	``compute_score``. Returns a list of ``(metric_name, score)`` tuples in
	scorer order.
	"""
	print('ground truth captions')
	print(gt_captions)

	print('RES CAPTIONS')
	print(res_captions)

	scorers = [
		(Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
		(Meteor(), "METEOR"),
		(Rouge(), "ROUGE_L"),
		(Cider(), "CIDEr"),
	]

	res = []

	for scorer, method in scorers:
		print('computing %s score...' % (scorer.method()))
		score, scores = scorer.compute_score(gt_captions, res_captions)
		if isinstance(method, list):
			# A list of names means the scorer returned one score per name.
			# ``scores`` stays in the zip so the iteration length is unchanged.
			for sc, _, m in zip(score, scores, method):
				print("%s: %0.3f" % (m, sc))
				res.append((m, sc))
		else:
			print("%s: %0.3f" % (method, score))
			res.append((method, score))

	return res
コード例 #15
0
ファイル: test_core.py プロジェクト: timgates42/cider
    def test_addlink(self, tmpdir, debug, verbose, name):
        """
        Tests that:
        1. Symlink directory is created & file is moved there.
        2. Symlink is created.
        3. Symlink is added to bootstrap.
        4. Cache is updated with new target.

        Expected errors:
        - StowError is raised if target does not exist.
        - StowError is raised if target already exists.
        """
        cider = Cider(False,
                      debug,
                      verbose,
                      cider_dir=str(tmpdir),
                      support_dir=str(tmpdir.join(".cache")))
        cider.add_symlink = MagicMock()

        source = os.path.abspath(str(tmpdir.join(random_str(min_length=1))))
        basename = os.path.basename(source)

        stow_dir = os.path.abspath(os.path.join(cider.symlink_dir, name))
        stow = os.path.join(stow_dir, basename)

        # StowError should be raised if source does not exist.
        with pytest.raises(StowError):
            cider.addlink(name, source)

        touch(source)
        cider.addlink(name, source)
        assert os.path.isdir(stow_dir)
        assert os.path.isfile(stow)
        assert os.path.islink(source)
        assert os.path.samefile(os.path.realpath(stow),
                                os.path.realpath(source))

        cider.add_symlink.assert_called_with(name, source)
        new_cache = cider._cached_targets()  # pylint:disable=W0212
        assert source in new_cache

        # StowError should be raised if the stowed file already exists.
        # Replace the symlink with a fresh regular file and stow it again
        # under the same name. The arguments were previously swapped
        # (``addlink(source, name)``), so this step exercised a different
        # source path than the one set up here.
        os.remove(source)
        touch(source)
        with pytest.raises(StowError):
            cider.addlink(name, source)
コード例 #16
0
ファイル: test_core.py プロジェクト: timgates42/cider
    def test_run_scripts(self, tmpdir, debug, verbose, before, after,
                         bootstrap):
        """run_scripts() spawns each selected script with cider's cwd and env."""
        cider = Cider(False, debug, verbose, cider_dir=str(tmpdir))
        cider.read_bootstrap = MagicMock(return_value=bootstrap)

        # Only the groups enabled by the flags are expected to run.
        scripts = []
        if before:
            scripts.extend(bootstrap.get("before-scripts", []))
        if after:
            scripts.extend(bootstrap.get("after-scripts", []))

        # TODO: Assert ordering
        spawn_patch = patch("cider.core.spawn", autospec=True, return_value=0)
        with spawn_patch as spawn:
            cider.run_scripts(before, after)
            for script in scripts:
                spawn.assert_any_call(
                    [script],
                    shell=True,
                    debug=debug,
                    cwd=cider.cider_dir,
                    env=cider.env,
                )
コード例 #17
0
ファイル: test_core.py プロジェクト: timgates42/cider
    def test_set_default(self, tmpdir, debug, verbose, domain, key, values,
                         bool_values, force):
        """set_default() writes the JSON-converted value and persists it.

        Runs over ``values`` plus a ``random_case`` variant of every entry in
        ``bool_values``; each value is set once as-is and once as ``str``.
        """
        cider = Cider(False, debug, verbose, cider_dir=str(tmpdir))
        cider.defaults = MagicMock()

        # On Python 3 map() returns an iterator and ``list + map`` raises
        # TypeError, so the extra values are built as a real list.
        cased_bools = [random_case(flag) for flag in bool_values]
        for value in values + cased_bools:
            json_value = cider.json_value(value)
            cider.set_default(domain, key, value, force=force)
            cider.defaults.write.assert_called_with(domain, key, json_value,
                                                    force)

            assert cider.read_defaults()[domain][key] == json_value

            # Verify str(value) => defaults.write(value)
            cider.set_default(domain, key, str(value), force=force)
            cider.defaults.write.assert_called_with(
                domain, key, cider.json_value(str(value)), force)
コード例 #18
0
    def evaluate(self):
        """Score captions from ``results.txt`` with BLEU-1..4 and CIDEr.

        Each line of ``results.txt`` is split on single spaces, its last
        token is replaced by ``'.'`` and its first token is dropped. For
        every image marked ``'test'`` in ``dic_flickr30k.json`` the first
        caption from ``cap_flickr30k.json`` and the next processed line are
        paired under the image id. Every score is passed to ``self.setEval``
        and printed.
        """
        cap_ = []
        # ``with`` closes the file; the handles here were previously leaked.
        with open(r'results.txt') as cap:
            for line in cap:
                line = line.split(' ')
                line[-1] = '.'
                del line[0]
                print(line)
                cap_.append(line)

        with open("cap_flickr30k.json") as f:
            captions = json.load(f)
        with open("dic_flickr30k.json") as f1:
            dics = json.load(f1)['images']

        # NOTE(review): ``res`` receives the caption from the JSON file and
        # ``gts`` receives the line read from results.txt (twice). The names
        # suggest the opposite roles -- confirm this is intentional.
        gts = {}
        res = {}
        pos = 0
        for i, entry in enumerate(dics):
            if entry['split'] == 'test':
                res[entry['id']] = [captions[i][0]['caption']]
                gts[entry['id']] = [cap_[pos], cap_[pos]]
                pos += 1

        # =================================================
        # Set up scorers
        # =================================================
        print('setting up scorers...')
        scorers = [
            (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
            (Cider(), "CIDEr")
        ]

        # =================================================
        # Compute scores
        # =================================================
        for scorer, method in scorers:
            print ('computing %s score...'%(scorer.method()))
            score, scores = scorer.compute_score(gts, res)
            if isinstance(method, list):
                # ``scores`` stays in the zip so the iteration length is
                # unchanged.
                for sc, _, m in zip(score, scores, method):
                    self.setEval(sc, m)
                    print ("%s: %0.3f"%(m, sc))
            else:
                self.setEval(score, method)
                print ("%s: %0.3f"%(method, score))
コード例 #19
0
ファイル: test_core.py プロジェクト: a-b/cider
    def test_addlink(self, tmpdir, debug, verbose, name):
        """
        Tests that:
        1. Symlink directory is created & file is moved there.
        2. Symlink is created.
        3. Symlink is added to bootstrap.
        4. Cache is updated with new target.

        Expected errors:
        - StowError is raised if target does not exist.
        - StowError is raised if target already exists.
        """
        cider = Cider(
            False, debug, verbose,
            cider_dir=str(tmpdir),
            support_dir=str(tmpdir.join(".cache"))
        )
        cider.add_symlink = MagicMock()

        source = os.path.abspath(str(tmpdir.join(random_str(min_length=1))))
        basename = os.path.basename(source)

        stow_dir = os.path.abspath(os.path.join(cider.symlink_dir, name))
        stow = os.path.join(stow_dir, basename)

        # StowError should be raised if source does not exist.
        with pytest.raises(StowError):
            cider.addlink(name, source)

        touch(source)
        cider.addlink(name, source)
        assert os.path.isdir(stow_dir)
        assert os.path.isfile(stow)
        assert os.path.islink(source)
        assert os.path.samefile(
            os.path.realpath(stow), os.path.realpath(source)
        )

        cider.add_symlink.assert_called_with(name, source)
        new_cache = cider._cached_targets()  # pylint:disable=W0212
        assert source in new_cache

        # StowError should be raised if the stowed file already exists.
        # Replace the symlink with a fresh regular file and stow it again
        # under the same name. The arguments were previously swapped
        # (``addlink(source, name)``), so this step exercised a different
        # source path than the one set up here.
        os.remove(source)
        touch(source)
        with pytest.raises(StowError):
            cider.addlink(name, source)
コード例 #20
0
ファイル: test_core.py プロジェクト: timgates42/cider
    def test_mklink(self, tmpdir, debug, verbose):
        """Check mklink() outcomes: error, success, refusal and forced link."""
        support_dir = str(tmpdir.join(".cache"))
        cider = Cider(False, debug, verbose,
                      cider_dir=str(tmpdir),
                      support_dir=support_dir)
        source = str(tmpdir.join(random_str(min_length=1)))
        target = str(tmpdir.join(random_str(min_length=1)))

        # No source file yet, so SymlinkError is expected.
        with pytest.raises(SymlinkError):
            assert not cider.mklink(source, target)

        # Once the source exists the link is made; a second call also works.
        touch(source)
        for _round in range(2):
            made = cider.mklink(source, target)
            assert made
            assert os.path.islink(target)

        # Swap the link for a plain file: mklink must now decline.
        os.remove(target)
        touch(target)
        declined = cider.mklink(source, target)
        assert not declined
        assert not os.path.islink(target)

        # With force=True the existing target may be removed first.
        with patch("cider._osx.move_to_trash", side_effect=os.remove):
            forced = cider.mklink(source, target, force=True)
            assert forced
コード例 #21
0
ファイル: test_core.py プロジェクト: timgates42/cider
    def test_missing(self, tmpdir, cask, debug, verbose, installed, brewed):
        """missing() should return, sorted, the brewed formulas with no users.

        A random "used by" table is generated for the brewed formulas; every
        formula left out of that table is recorded as an orphan.
        """
        orphans = []

        def generate_uses():
            used_by = {}
            for formula in brewed:
                others = [name for name in installed if name != formula]
                # Short-circuit keeps the coin flip from running when there
                # is nobody who could depend on the formula.
                if not others or not random.choice([True, False]):
                    orphans.append(formula)
                    continue
                how_many = random.randint(1, len(others))
                used_by[formula] = random.sample(others, how_many)

            def lookup(formula):
                return used_by.get(formula, [])

            return lookup

        cider = Cider(cask, debug, verbose, cider_dir=str(tmpdir))
        cider.brew = MagicMock()
        cider.brew.ls = MagicMock(return_value=brewed)
        cider.brew.uses = MagicMock(side_effect=generate_uses())
        cider.installed = MagicMock(return_value=installed)

        expected = sorted(orphans)
        assert cider.missing() == expected
コード例 #22
0
ファイル: test_core.py プロジェクト: timgates42/cider
    def test_apply_defaults(self, tmpdir, debug, verbose, defaults):
        """Every stored default must be passed to ``defaults.write``."""
        cider = Cider(False, debug, verbose, cider_dir=str(tmpdir))
        cider.defaults = MagicMock()
        cider.read_defaults = MagicMock(return_value=defaults)
        cider.apply_defaults()

        write = cider.defaults.write
        for domain in defaults:
            options = defaults[domain]
            for key in options:
                write.assert_any_call(domain, key, options[key])
コード例 #23
0
ファイル: test_core.py プロジェクト: timgates42/cider
    def test_rm(self, tmpdir, cask, debug, verbose, formulas):
        """Removed names go to brew.rm once and vanish from the bootstrap."""
        cider = Cider(cask, debug, verbose, cider_dir=str(tmpdir))
        cider.brew = MagicMock()

        cider.rm(*formulas)
        cider.brew.rm.assert_called_once_with(*formulas)

        # The bootstrap section depends on whether we are in cask mode.
        section = "casks" if cask else "formulas"
        for formula in formulas:
            still_listed = cider.read_bootstrap().get(section, [])
            assert formula not in still_listed
コード例 #24
0
ファイル: test_core.py プロジェクト: lorin/cider
    def test_set_default(
        self, tmpdir, debug, verbose, domain, key, values, bool_values, force
    ):
        """set_default() writes the JSON-converted value and persists it.

        Runs over ``values`` plus a ``random_case`` variant of every entry in
        ``bool_values``; each value is set once as-is and once as ``str``.
        """
        cider = Cider(False, debug, verbose, cider_dir=str(tmpdir))
        cider.defaults = MagicMock()

        # On Python 3 map() returns an iterator and ``list + map`` raises
        # TypeError, so the extra values are built as a real list.
        cased_bools = [random_case(flag) for flag in bool_values]
        for value in values + cased_bools:
            json_value = cider.json_value(value)
            cider.set_default(domain, key, value, force=force)
            cider.defaults.write.assert_called_with(
                domain, key, json_value, force
            )

            assert cider.read_defaults()[domain][key] == json_value

            # Verify str(value) => defaults.write(value)
            cider.set_default(domain, key, str(value), force=force)
            cider.defaults.write.assert_called_with(
                domain, key, cider.json_value(str(value)), force
            )
コード例 #25
0
ファイル: test_core.py プロジェクト: lorin/cider
    def test_missing(self, tmpdir, cask, debug, verbose,
                     installed, brewed):
        """Brewed formulas that nothing uses are reported by missing().

        The mocked ``brew.uses`` answers from a randomly generated table;
        formulas absent from the table are collected in ``orphans`` and
        compared, sorted, with the result.
        """
        orphans = []

        def generate_uses():
            table = {}
            for formula in brewed:
                candidates = [x for x in installed if x != formula]
                has_users = candidates and random.choice([True, False])
                if has_users:
                    count = random.randint(1, len(candidates))
                    table[formula] = random.sample(candidates, count)
                else:
                    orphans.append(formula)

            def uses(formula):
                return table.get(formula, [])

            return uses

        cider = Cider(cask, debug, verbose, cider_dir=str(tmpdir))
        cider.brew = MagicMock()
        cider.brew.ls = MagicMock(return_value=brewed)
        cider.brew.uses = MagicMock(side_effect=generate_uses())
        cider.installed = MagicMock(return_value=installed)

        assert cider.missing() == sorted(orphans)
コード例 #26
0
def score(ref, hypo):
    """Compute BLEU-1..4, METEOR, ROUGE-L and CIDEr for ``hypo`` vs ``ref``.

    Both arguments are passed unchanged to each scorer's ``compute_score``.

    Returns:
        dict mapping metric name to the score reported by its scorer.
    """
    scorers = [
        (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
        (Meteor(), "METEOR"),
        (Rouge(), "ROUGE_L"),
        (Cider(), "CIDEr")
    ]
    final_scores = {}
    for scorer, method in scorers:
        # ``value`` rather than ``score`` so the local does not shadow this
        # function's own name.
        value, _ = scorer.compute_score(ref, hypo)
        if isinstance(value, list):
            # A list result carries one score per name in ``method``.
            final_scores.update(zip(method, value))
        else:
            final_scores[method] = value

    return final_scores
コード例 #27
0
ファイル: test_core.py プロジェクト: timgates42/cider
    def test_unlink(self, tmpdir, debug, verbose, name, links):
        """
        Verifies that unlink():
        1. Moves each stowed file back to its original location.
        2. Removes the symlinks from bootstrap.
        3. Drops the targets from the cache.
        4. Deletes the symlink directory once it is empty.

        Expected errors:
        - StowError is raised if no symlink was found.
        - SymlinkError is raised if target already exists.
        """
        cider_root = tmpdir.join("cider")
        cider = Cider(False,
                      debug,
                      verbose,
                      cider_dir=str(cider_root),
                      support_dir=str(tmpdir.join("cider", ".cache")))
        cider.remove_symlink = MagicMock()

        stow_dir = os.path.abspath(os.path.join(cider.symlink_dir, name))
        os.makedirs(stow_dir)

        # (stowed file, symlink location) for every link under test.
        pairs = [
            (os.path.join(stow_dir, link), str(tmpdir.join(link)))
            for link in links
        ]
        for source, target in pairs:
            touch(source)
            os.symlink(source, target)
            cider.add_symlink(name, target)

        cider.unlink(name)

        new_cache = cider._cached_targets()  # pylint:disable=W0212
        for source, target in pairs:
            assert os.path.exists(target)
            assert not os.path.islink(target)
            assert not os.path.exists(source)
            assert target not in new_cache

        cider.remove_symlink.assert_called_with(name)
        assert not os.path.exists(stow_dir)
コード例 #28
0
ファイル: test_core.py プロジェクト: a-b/cider
    def test_unlink(self, tmpdir, debug, verbose, name, links):
        """
        Verifies that unlink():
        1. Moves each stowed file back to its original location.
        2. Removes the symlinks from bootstrap.
        3. Drops the targets from the cache.
        4. Deletes the symlink directory once it is empty.

        Expected errors:
        - StowError is raised if no symlink was found.
        - SymlinkError is raised if target already exists.
        """
        cider_dir = str(tmpdir.join("cider"))
        support_dir = str(tmpdir.join("cider", ".cache"))
        cider = Cider(
            False, debug, verbose, cider_dir=cider_dir, support_dir=support_dir
        )
        cider.remove_symlink = MagicMock()

        stow_dir = os.path.abspath(os.path.join(cider.symlink_dir, name))
        os.makedirs(stow_dir)

        def paths_for(link):
            """Return the (stowed file, symlink location) pair for a link."""
            return os.path.join(stow_dir, link), str(tmpdir.join(link))

        for link in links:
            source, target = paths_for(link)
            touch(source)
            os.symlink(source, target)
            cider.add_symlink(name, target)

        cider.unlink(name)

        new_cache = cider._cached_targets()  # pylint:disable=W0212
        for link in links:
            source, target = paths_for(link)
            assert os.path.exists(target)
            assert not os.path.islink(target)
            assert not os.path.exists(source)
            assert target not in new_cache

        cider.remove_symlink.assert_called_with(name)
        assert not os.path.exists(stow_dir)
コード例 #29
0
class GAN:
    """Adversarial trainer pairing a caption Generator with a Discriminator.

    Training runs in three stages (see ``train``): generator pre-training
    with the sequence loss, discriminator pre-training, then alternating
    generator and discriminator updates.
    """

    def __init__(self, device, args):
        """Build models, losses, optimizers, evaluator and data loaders.

        Args:
            device: Device the models and batches are moved to.
            args: Settings object. This class reads ``batch_size``,
                ``checkpoint_path``, ``generator_lr``, ``discriminator_lr``,
                the ``load_*`` flags and the epoch/iteration counts from it;
                it is also passed to the project classes constructed here.
        """
        self.device = device
        self.args = args
        self.batch_size = args.batch_size
        self.generator_checkpoint_path = os.path.join(args.checkpoint_path, 'generator.pth')
        self.discriminator_checkpoint_path = os.path.join(args.checkpoint_path, 'discriminator.pth')
        # makedirs(exist_ok=True) also creates missing parents and avoids
        # the check-then-create race of isdir() followed by mkdir().
        os.makedirs(args.checkpoint_path, exist_ok=True)
        self.generator = Generator(args).to(self.device)
        self.discriminator = Discriminator(args).to(self.device)
        self.sequence_loss = SequenceLoss()
        self.reinforce_loss = ReinforceLoss()
        self.generator_optimizer = optim.Adam(self.generator.parameters(), lr=args.generator_lr)
        self.discriminator_optimizer = optim.Adam(self.discriminator.parameters(), lr=args.discriminator_lr)
        self.evaluator = Evaluator('val', self.device, args)
        self.cider = Cider(args)
        generator_dataset = CaptionDataset('train', args)
        self.generator_loader = DataLoader(generator_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)
        discriminator_dataset = DiscCaption('train', args)
        self.discriminator_loader = DataLoader(discriminator_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def train(self):
        """Load or pre-train each network, then run adversarial training."""
        if self.args.load_generator:
            self.generator.load_state_dict(torch.load(self.generator_checkpoint_path))
        else:
            self._pretrain_generator()
        if self.args.load_discriminator:
            self.discriminator.load_state_dict(torch.load(self.discriminator_checkpoint_path))
        else:
            self._pretrain_discriminator()
        self._train_gan()

    def _pretrain_generator(self):
        """Train the generator with the sequence loss only.

        After every epoch the generator is evaluated and its weights are
        saved to ``generator_checkpoint_path``.
        """
        # ``step`` rather than ``iter`` so the builtin that ``_train_gan``
        # relies on is never shadowed.
        step = 0
        for epoch in range(self.args.pretrain_generator_epochs):
            self.generator.train()
            for data in self.generator_loader:
                for name, item in data.items():
                    data[name] = item.to(self.device)
                self.generator.zero_grad()
                probs = self.generator(data['fc_feats'], data['att_feats'], data['att_masks'], data['labels'])
                loss = self.sequence_loss(probs, data['labels'])
                loss.backward()
                self.generator_optimizer.step()
                print('iter {}, epoch {}, generator loss {:.3f}'.format(step, epoch, loss.item()))
                step += 1
            self.evaluator.evaluate_generator(self.generator)
            torch.save(self.generator.state_dict(), self.generator_checkpoint_path)

    def _pretrain_discriminator(self):
        """Train the discriminator alone against the current generator.

        After every epoch the discriminator is evaluated and its weights are
        saved to ``discriminator_checkpoint_path``.
        """
        step = 0
        for epoch in range(self.args.pretrain_discriminator_epochs):
            self.discriminator.train()
            for data in self.discriminator_loader:
                loss = self._train_discriminator(data)
                print('iter {}, epoch {}, discriminator loss {:.3f}'.format(step, epoch, loss))
                step += 1
            self.evaluator.evaluate_discriminator(generator=self.generator, discriminator=self.discriminator)
            torch.save(self.discriminator.state_dict(), self.discriminator_checkpoint_path)

    def _train_gan(self):
        """Alternate generator and discriminator updates.

        Runs ``args.train_gan_iters`` iterations with one update of each
        network per iteration. Loaders are restarted when exhausted. Both
        networks are evaluated and checkpointed every 10000 iterations
        (iteration 0 excluded).
        """
        generator_iter = iter(self.generator_loader)
        discriminator_iter = iter(self.discriminator_loader)
        for i in range(self.args.train_gan_iters):
            print('iter {}'.format(i))
            # The two inner ranges set how many updates each network gets
            # per outer iteration (currently one each).
            for _ in range(1):
                try:
                    data = next(generator_iter)
                except StopIteration:
                    generator_iter = iter(self.generator_loader)
                    data = next(generator_iter)
                result = self._train_generator(data)
                print('generator loss {:.3f}, fake prob {:.3f}, cider score {:.3f}'.format(result['loss'], result['fake_prob'], result['cider_score']))
            for _ in range(1):
                try:
                    data = next(discriminator_iter)
                except StopIteration:
                    discriminator_iter = iter(self.discriminator_loader)
                    data = next(discriminator_iter)
                loss = self._train_discriminator(data)
                print('discriminator loss {:.3f}'.format(loss))
            if i != 0 and i % 10000 == 0:
                self.evaluator.evaluate_generator(self.generator)
                torch.save(self.generator.state_dict(), self.generator_checkpoint_path)
                self.evaluator.evaluate_discriminator(generator=self.generator, discriminator=self.discriminator)
                torch.save(self.discriminator.state_dict(), self.discriminator_checkpoint_path)

    def _train_generator(self, data):
        """Run one generator update on a batch.

        The optimized loss is the sequence loss plus the reinforce loss,
        where sampled sequences supply the reward and greedy-decoded
        sequences supply the baseline.

        Args:
            data: Batch dict; its values are moved to ``self.device`` in
                place.

        Returns:
            dict with ``'loss'`` (the sequence loss only, as a float),
            ``'fake_prob'`` and ``'cider_score'`` for the sampled sequences.
        """
        self.generator.train()
        for name, item in data.items():
            data[name] = item.to(self.device)
        self.generator.zero_grad()

        probs = self.generator(data['fc_feats'], data['att_feats'], data['att_masks'], data['labels'])
        loss1 = self.sequence_loss(probs, data['labels'])

        seqs, probs = self.generator.sample(data['fc_feats'], data['att_feats'], data['att_masks'])
        greedy_seqs = self.generator.greedy_decode(data['fc_feats'], data['att_feats'], data['att_masks'])
        reward, fake_prob, score = self._get_reward(data, seqs)
        baseline, _, _ = self._get_reward(data, greedy_seqs)
        loss2 = self.reinforce_loss(reward, baseline, probs, seqs)

        loss = loss1 + loss2
        loss.backward()
        self.generator_optimizer.step()
        result = {
            'loss': loss1.item(),
            'fake_prob': fake_prob,
            'cider_score': score
        }
        return result

    def _train_discriminator(self, data):
        """Run one discriminator update on a batch and return the loss.

        The discriminator scores real labels, ``wrong_labels`` and sequences
        sampled from the generator (sampled without gradient tracking).

        Args:
            data: Batch dict; its values are moved to ``self.device`` in
                place.

        Returns:
            The scalar loss as a Python float.
        """
        self.discriminator.train()
        for name, item in data.items():
            data[name] = item.to(self.device)
        self.discriminator.zero_grad()

        real_probs = self.discriminator(data['fc_feats'], data['att_feats'], data['att_masks'], data['labels'])
        wrong_probs = self.discriminator(data['fc_feats'], data['att_feats'], data['att_masks'], data['wrong_labels'])

        # generate fake data
        with torch.no_grad():
            fake_seqs, _ = self.generator.sample(data['fc_feats'], data['att_feats'], data['att_masks'])
        fake_probs = self.discriminator(data['fc_feats'], data['att_feats'], data['att_masks'], fake_seqs)

        # Weighted log-likelihood: real 0.5, wrong 0.25, fake 0.25. The
        # 1e-10 keeps log() away from zero.
        loss = -(0.5 * torch.log(real_probs + 1e-10) + 0.25 * torch.log(1 - wrong_probs + 1e-10) + 0.25 * torch.log(1 - fake_probs + 1e-10)).mean()
        loss.backward()
        self.discriminator_optimizer.step()
        return loss.item()

    def _get_reward(self, data, seqs):
        """Combine discriminator output and CIDEr scores into a reward.

        Args:
            data: Batch dict already on ``self.device``.
            seqs: Sequences to score.

        Returns:
            Tuple ``(reward, fake_prob, score)``: the discriminator output
            plus the CIDEr scores, the mean discriminator output as a float,
            and the mean of the CIDEr scores.
        """
        probs = self.discriminator(data['fc_feats'], data['att_feats'], data['att_masks'], seqs)
        scores = self.cider.get_scores(seqs.cpu().numpy(), data['images'].cpu().numpy())
        reward = probs + torch.tensor(scores, dtype=torch.float, device=self.device)
        fake_prob = probs.mean().item()
        score = scores.mean()
        return reward, fake_prob, score
コード例 #30
0
ファイル: test_core.py プロジェクト: timgates42/cider
 def test_untap(self, tmpdir, debug, verbose, tap):
     """untap() forwards to brew.untap and removes the tap from bootstrap."""
     cider = Cider(False, debug, verbose, cider_dir=str(tmpdir))
     brew = MagicMock()
     cider.brew = brew

     cider.untap(tap)

     brew.untap.assert_called_with(tap)
     remaining_taps = cider.read_bootstrap().get("taps", [])
     assert tap not in remaining_taps
コード例 #31
0
ファイル: test_core.py プロジェクト: timgates42/cider
    def test_relink(self, tmpdir, debug, verbose, force):
        """
        Tests that:
        1. Target directories are created.
        2. For each source in glob(key), mklink(src, expandtarget(src, target))
           is called.
        3. Previously-cached targets are removed.
        4. Cache is updated with new targets.

        ``mklink`` is replaced by a mock that always returns True, and
        ``read_bootstrap`` is mocked to expose one symlink entry at a time.
        """
        def generate_symlinks():
            # Builds a small source tree under tmpdir and returns a mapping
            # of source glob -> target path to feed into relink().
            srcdir = tmpdir.join(random_str(min_length=1))

            def symkey(directory, key):
                # Path of directory/key expressed relative to srcdir.
                return str(directory.join(key).relto(srcdir))

            def symvalue(directory, value):
                # Re-append the trailing "/" when ``value`` had one
                # (presumably str(join(...)) drops it -- TODO confirm).
                return str(directory.join(value)) + ("/" if value.endswith("/")
                                                     else "")

            outerdir = srcdir.join(random_str(min_length=1))
            innerdir = outerdir.join(random_str(min_length=1))
            targetdir = tmpdir.join(random_str(min_length=1))

            ext = random_str(min_length=1, max_length=8)
            os.makedirs(str(innerdir))

            # Between 0 and 10 files sharing one random extension.
            for _ in range(random.randint(0, 10)):
                touch(str(innerdir.join("{0}.{1}".format(random_str(), ext))))

            # One extra file directly inside outerdir.
            path = str(outerdir.join(random_str(min_length=1)))
            touch(path)

            # NOTE(review): the first two keys are the same expression, so
            # the "a/b/c" entry replaces the "a/b/c/" one and the
            # directory-style target never reaches the loop below -- confirm
            # whether both cases were meant to be exercised.
            return {
                symkey(outerdir, "*/*." + ext): symvalue(targetdir, "a/b/c/"),
                symkey(outerdir, "*/*." + ext): symvalue(targetdir, "a/b/c"),
                symkey(outerdir, path): symvalue(targetdir, "a/b/d"),
            }

        cider = Cider(False,
                      debug,
                      verbose,
                      cider_dir=str(tmpdir),
                      support_dir=str(tmpdir.join(".cache")))
        cider.mklink = MagicMock(return_value=True)

        for srcglob, target in generate_symlinks().items():
            # A wildcard source paired with a non-directory target is the
            # combination expected to raise SymlinkError.
            invalid = not isdirname(target) and ("*" in srcglob
                                                 or "?" in srcglob)
            old_targets = cider._cached_targets()  # pylint:disable=W0212
            cider.read_bootstrap = MagicMock(
                return_value={"symlinks": {
                    srcglob: target
                }})

            # ``empty`` is defined outside this block; presumably a no-op
            # context manager for the valid case -- TODO confirm.
            # When ``invalid`` is true and relink() raises, the remaining
            # statements in this block are skipped.
            with pytest.raises(SymlinkError) if invalid else empty():
                new_targets = set(cider.relink(force))
                for src in iglob(srcglob):
                    cider.mklink.assert_called_with(
                        src, cider.expandtarget(src, target))

                assert os.path.isdir(os.path.dirname(target))
                for dead_target in set(old_targets) - new_targets:
                    assert not os.path.exists(dead_target)

                new_cache = cider._cached_targets()  # pylint:disable=W0212
                # Every target returned by relink() must be in the cache.
                assert new_targets == set(new_cache) & new_targets
コード例 #32
0
ファイル: test_core.py プロジェクト: timgates42/cider
 def test_read_bootstrap(self, tmpdir, cask, debug, verbose, data):
     """read_bootstrap() returns the data produced by ``read_config``.

     The loader must be called with the bootstrap file and an empty dict.
     """
     with patch("cider.core.read_config") as read_config:
         cider = Cider(cask, debug, verbose, cider_dir=str(tmpdir))
         read_config.return_value = data

         loaded = cider.read_bootstrap()
         assert loaded == data
         read_config.assert_called_with(cider.bootstrap_file, {})
コード例 #33
0
ファイル: test_core.py プロジェクト: timgates42/cider
 def test_remove_default(self, tmpdir, debug, verbose, domain, key):
     """remove_default() deletes the key via defaults and from the file."""
     cider = Cider(False, debug, verbose, cider_dir=str(tmpdir))
     defaults = MagicMock()
     cider.defaults = defaults

     cider.remove_default(domain, key)

     defaults.delete.assert_called_with(domain, key)
     remaining = cider.read_defaults().get(domain, [])
     assert key not in remaining
コード例 #34
0
class RL:
    """Two-stage trainer for a caption ``Generator``.

    Stage one (``_train_xe``) minimises ``SequenceLoss`` between the
    generator's output and the ground-truth labels.  Stage two
    (``_train_rl``) minimises ``ReinforceLoss``, using the scores returned by
    ``Cider.get_scores`` as the reward.  After every epoch of either stage
    the generator is evaluated and its state dict is written to
    ``<checkpoint_dir>/generator.pth``.
    """

    def __init__(self, device, args):
        """Build the model, losses, optimizer, evaluator and data loader.

        Args:
            device: torch device the generator and all batches are moved to.
            args: configuration object; this class reads ``batch_size``,
                ``checkpoint_dir``, ``learning_rate``, ``load_generator``,
                ``xe_epochs``, ``rl_epochs``, ``grad_clip_threshold``,
                ``learning_rate_decay_every`` and ``learning_rate_decay_rate``
                from it and forwards it to the project components.
        """
        self.device = device
        self.args = args
        self.batch_size = args.batch_size
        self.checkpoint_path = os.path.join(args.checkpoint_dir, 'generator.pth')
        if not os.path.isdir(args.checkpoint_dir):
            # makedirs (not mkdir) so a nested checkpoint path whose parents
            # do not exist yet is created instead of raising.
            os.makedirs(args.checkpoint_dir)
        self.generator = Generator(args).to(self.device)
        self.sequence_loss = SequenceLoss()
        self.reinforce_loss = ReinforceLoss()
        self.optimizer = optim.Adam(self.generator.parameters(), lr=args.learning_rate)
        self.evaluator = Evaluator('val', self.device, args)
        self.cider = Cider(args)
        dataset = CaptionDataset('train', args)
        self.loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def train(self):
        """Run the full schedule.

        When ``args.load_generator`` is set, the supervised stage is skipped
        and the generator weights are restored from the checkpoint instead.
        The reinforcement stage always runs afterwards.
        """
        if self.args.load_generator:
            # map_location lets a checkpoint saved on another device (e.g. a
            # GPU) be restored onto the device this trainer was given.
            state = torch.load(self.checkpoint_path, map_location=self.device)
            self.generator.load_state_dict(state)
        else:
            self._train_xe()
        self._train_rl()

    def _to_device(self, data):
        """Move every entry of a batch dict to ``self.device`` in place."""
        for name, item in data.items():
            data[name] = item.to(self.device)
        return data

    def _train_xe(self):
        """Supervised stage: optimise ``sequence_loss`` for ``xe_epochs`` epochs."""
        step = 0
        for epoch in range(self.args.xe_epochs):
            self._decay_learning_rate(epoch)
            self.generator.train()
            for data in self.loader:
                self._to_device(data)
                self.generator.zero_grad()
                probs = self.generator(data['fc_feats'], data['att_feats'], data['att_masks'], data['labels'])
                loss = self.sequence_loss(probs, data['labels'])
                loss.backward()
                self._clip_gradient()
                self.optimizer.step()
                print('iter {}, epoch {}, loss {:.3f}'.format(step, epoch, loss.item()))
                step += 1
            self.evaluator.evaluate_generator(self.generator)
            torch.save(self.generator.state_dict(), self.checkpoint_path)

    def _train_rl(self):
        """Reinforcement stage: optimise ``reinforce_loss`` for ``rl_epochs`` epochs.

        Epoch numbering continues from ``xe_epochs`` so the learning-rate
        schedule carries on where the supervised stage stopped.
        """
        step = 0
        for epoch in range(self.args.xe_epochs, self.args.xe_epochs + self.args.rl_epochs):
            self._decay_learning_rate(epoch)
            for data in self.loader:
                # Training mode is re-enabled on every step -- presumably
                # because decoding/evaluation may leave the generator in eval
                # mode; TODO confirm against Generator.
                self.generator.train()
                self._to_device(data)
                self.generator.zero_grad()
                loss, reward = self._rl_core1(data)
                loss.backward()
                self._clip_gradient()
                self.optimizer.step()
                print('iter {}, epoch {}, cider score {:.3f}'.format(step, epoch, reward.mean().item()))
                step += 1
            self.evaluator.evaluate_generator(self.generator)
            torch.save(self.generator.state_dict(), self.checkpoint_path)

    def _get_reward(self, data, seqs):
        """Score ``seqs`` with ``self.cider`` and return a float tensor on ``self.device``."""
        scores = self.cider.get_scores(seqs.cpu().numpy(), data['images'].cpu().numpy())
        reward = torch.tensor(scores, dtype=torch.float, device=self.device)
        return reward

    def _clip_gradient(self):
        """Clamp every gradient element to ``[-grad_clip_threshold, grad_clip_threshold]``.

        Parameters that received no gradient in the last backward pass
        (``grad is None``) are skipped instead of raising AttributeError.
        """
        threshold = self.args.grad_clip_threshold
        for group in self.optimizer.param_groups:
            for param in group['params']:
                if param.grad is None:
                    continue
                param.grad.data.clamp_(-threshold, threshold)

    def _decay_learning_rate(self, epoch):
        """Apply step decay to the optimizer's learning rate.

        Only acts when ``epoch`` is a multiple of
        ``learning_rate_decay_every``; the rate then becomes
        ``learning_rate * learning_rate_decay_rate ** (epoch // decay_every)``.
        """
        if epoch % self.args.learning_rate_decay_every == 0:
            learning_rate = self.args.learning_rate * (self.args.learning_rate_decay_rate ** (epoch // self.args.learning_rate_decay_every))
            for group in self.optimizer.param_groups:
                group['lr'] = learning_rate
            print('learning rate: {}'.format(learning_rate))

    def _rl_core1(self, data):
        """Compute the RL loss with the greedy decode's reward as baseline.

        Returns:
            ``(loss, reward)`` where ``reward`` is the score tensor of the
            sampled sequences.
        """
        seqs, probs = self.generator.sample(data['fc_feats'], data['att_feats'], data['att_masks'])
        greedy_seqs = self.generator.greedy_decode(data['fc_feats'], data['att_feats'], data['att_masks'])
        reward = self._get_reward(data, seqs)
        baseline = self._get_reward(data, greedy_seqs)
        loss = self.reinforce_loss(reward, baseline, probs, seqs)
        return loss, reward

    def _rl_core2(self, data):
        """Alternative RL loss: baseline is the mean reward over several samples.

        Draws ``num_samples`` samples per batch, scores each, and uses the
        mean over the sample axis as the baseline for every draw.  Not called
        by ``_train_rl`` in this class.

        Returns:
            ``(loss, reward)`` with all samples flattened into the leading
            dimension.
        """
        num_samples = 8
        all_seqs = []
        all_probs = []
        all_reward = []
        for _ in range(num_samples):
            seqs, probs = self.generator.sample(data['fc_feats'], data['att_feats'], data['att_masks'])
            all_seqs.append(seqs)
            all_probs.append(probs)
            all_reward.append(self._get_reward(data, seqs))
        all_seqs = torch.stack(all_seqs)
        all_probs = torch.stack(all_probs)
        all_reward = torch.stack(all_reward)
        # Collapse the (sample, batch) leading axes into one.
        seqs = all_seqs.view(-1, all_seqs.size(2))
        probs = all_probs.view(-1, all_probs.size(2), all_probs.size(3))
        # Mean over the sample axis, repeated for every draw, then flattened
        # so it lines up element-wise with ``reward``.
        baseline = all_reward.mean(0, keepdim=True).expand(num_samples, -1).contiguous().view(-1)
        reward = all_reward.view(-1)
        loss = self.reinforce_loss(reward, baseline, probs, seqs)
        return loss, reward
コード例 #35
0
ファイル: test_core.py プロジェクト: lorin/cider
    def test_relink(self, tmpdir, debug, verbose, force):
        """
        Tests that:
        1. Target directories are created.
        2. For each source in glob(key), mklink(src, expandtarget(src, target))
           is called.
        3. Previously-cached targets are removed.
        4. Cache is updated with new targets.

        A glob source paired with a target that is not a directory name is
        expected to raise SymlinkError instead.
        """
        def generate_symlinks():
            """Create source files under tmpdir; return (srcglob, target) pairs.

            A list of pairs is returned rather than a dict: two of the cases
            share the same source glob, and as dict keys the later entry
            would silently replace the earlier one, dropping a test case.
            """
            srcdir = tmpdir.join(random_str(min_length=1))

            def symkey(directory, key):
                # Source globs are expressed relative to srcdir.
                return str(directory.join(key).relto(srcdir))

            def symvalue(directory, value):
                # Joining drops a trailing slash; restore it so targets that
                # name a directory stay recognisable as such.
                return str(directory.join(value)) + (
                    "/" if value.endswith("/") else ""
                )

            outerdir = srcdir.join(random_str(min_length=1))
            innerdir = outerdir.join(random_str(min_length=1))
            targetdir = tmpdir.join(random_str(min_length=1))

            ext = random_str(min_length=1, max_length=8)
            os.makedirs(str(innerdir))

            for _ in range(random.randint(0, 10)):
                touch(str(innerdir.join("{0}.{1}".format(random_str(), ext))))

            path = str(outerdir.join(random_str(min_length=1)))
            touch(path)

            glob_key = symkey(outerdir, "*/*." + ext)
            return [
                # Glob into a directory target.
                (glob_key, symvalue(targetdir, "a/b/c/")),
                # Same glob into a non-directory target.
                (glob_key, symvalue(targetdir, "a/b/c")),
                # Single file into a file target.
                (symkey(outerdir, path), symvalue(targetdir, "a/b/d")),
            ]

        cider = Cider(False, debug, verbose, cider_dir=str(tmpdir))
        cider.mklink = MagicMock(return_value=True)

        for srcglob, target in generate_symlinks():
            invalid = not isdirname(target) and ("*" in srcglob or
                                                 "?" in srcglob)
            old_targets = cider._cached_targets()  # pylint:disable=W0212
            cider.read_bootstrap = MagicMock(return_value={
                "symlinks": {srcglob: target}
            })

            with pytest.raises(SymlinkError) if invalid else empty():
                new_targets = set(cider.relink(force))
                for src in iglob(srcglob):
                    cider.mklink.assert_called_with(src, cider.expandtarget(
                        src, target
                    ))

                assert os.path.isdir(os.path.dirname(target))
                for dead_target in set(old_targets) - new_targets:
                    assert not os.path.exists(dead_target)

                # Every target returned by relink must be in the cache.
                new_cache = cider._cached_targets()  # pylint:disable=W0212
                assert new_targets == set(new_cache).intersection(new_targets)
コード例 #36
0
ファイル: test_core.py プロジェクト: a-b/cider
 def test_read_bootstrap(self, tmpdir, cask, debug, verbose, data):
     """read_bootstrap() should hand back the data produced by read_config.

     read_config is patched; it must have been called with the instance's
     bootstrap_file and an empty dict.
     """
     with patch("cider.core.read_config") as stub:
         instance = Cider(cask, debug, verbose, cider_dir=str(tmpdir))
         # The return value is set after the instance exists, as before.
         stub.return_value = data

         result = instance.read_bootstrap()

         assert result == data
         stub.assert_called_with(instance.bootstrap_file, {})
コード例 #37
0
ファイル: test_core.py プロジェクト: lorin/cider
 def test_untap(self, tmpdir, debug, verbose, tap):
     """untap() forwards to brew.untap and removes the tap from bootstrap.

     `brew` is replaced with a MagicMock; afterwards the tap must not be
     listed under "taps" in the data returned by read_bootstrap().
     """
     subject = Cider(False, debug, verbose, cider_dir=str(tmpdir))
     subject.brew = MagicMock()

     subject.untap(tap)

     subject.brew.untap.assert_called_with(tap)
     recorded_taps = subject.read_bootstrap().get("taps", [])
     assert tap not in recorded_taps
コード例 #38
0
ファイル: test_core.py プロジェクト: lorin/cider
 def test_remove_default(self, tmpdir, debug, verbose, domain, key):
     """Removing a default calls defaults.delete and forgets the key.

     `defaults` is swapped for a MagicMock so the delete call can be
     inspected; read_defaults() must then not list the key for the domain.
     """
     instance = Cider(False, debug, verbose, cider_dir=str(tmpdir))
     instance.defaults = MagicMock()

     instance.remove_default(domain, key)

     instance.defaults.delete.assert_called_with(domain, key)
     domain_entries = instance.read_defaults().get(domain, [])
     assert key not in domain_entries