Beispiel #1
0
 def numpy_apply_policies(arglist):
     """Apply CTAugment policies to a 3-D input or to each entry of a stacked input.

     ``arglist`` is unpacked as ``(x, cta, probe)``.

     * If ``x.ndim == 3`` the caller must ask for probing (``probe`` truthy).
       A policy is obtained via ``cta.policy(probe=True)`` and the result is a
       dict with ``policy``, the augmented ``probe`` and the untouched ``image``.
     * Otherwise ``probe`` must be falsy. ``x[0]`` is kept unchanged; every
       later entry of ``x`` is passed through ``ctaugment.apply`` with its own
       ``cta.policy(probe=False)`` result followed by a cutout op. The entries
       are stacked and cast to float32 (``'f'``) under the ``image`` key.
     """
     x, cta, probe = arglist

     if x.ndim != 3:
         assert not probe

         def make_cutout_policy():
             # Called once per augmented entry, so each gets its own cta.policy result.
             return cta.policy(probe=False) + [ctaugment.OP('cutout', (1,))]

         frames = [x[0]]
         for frame in x[1:]:
             frames.append(ctaugment.apply(frame, make_cutout_policy()))
         return dict(image=np.stack(frames).astype('f'))

     assert probe
     policy = cta.policy(probe=True)
     probed = ctaugment.apply(x, policy)
     return dict(policy=policy, probe=probed, image=x)
 def numpy_apply_policies(arglist):
     """Apply CTAugment policies to a 3-D input or to each entry of a stacked input.

     ``arglist`` is unpacked as ``(x, cta, probe, anchoring)``. ``anchoring``
     may be omitted; it then defaults to ``('weak', 'strong')``, which keeps
     ``x[0]`` unchanged and cutout-augments every later entry. That default is
     exactly the behaviour of the three-element ``(x, cta, probe)`` form, so
     callers using either calling convention keep working.

     * If ``x.ndim == 3`` the caller must ask for probing (``probe`` truthy).
       A policy is obtained via ``cta.policy(probe=True)`` and the result is a
       dict with ``policy``, the augmented ``probe`` and the untouched
       ``image``. ``anchoring`` is not used on this path.
     * Otherwise ``probe`` must be falsy. ``anchoring[0]`` controls ``x[0]``
       and ``anchoring[1]`` controls ``x[1:]``: ``'weak'`` leaves the entries
       unchanged; any other value passes each entry through ``ctaugment.apply``
       with its own ``cta.policy(probe=False)`` result followed by a cutout op.
       The entries are stacked and cast to float32 (``'f'``) under ``image``.

     Raises:
         ValueError: if ``arglist`` has fewer than three or more than four items.
         AssertionError: if ``probe`` does not match the dimensionality of ``x``.
     """
     x, cta, probe, *extra = arglist
     if len(extra) > 1:
         raise ValueError(
             'expected (x, cta, probe[, anchoring]), got %d items' % (3 + len(extra)))
     # Default matches the three-element variant: x[0] kept, x[1:] augmented.
     anchoring = extra[0] if extra else ('weak', 'strong')

     if x.ndim == 3:
         assert probe
         policy = cta.policy(probe=True)
         return dict(policy=policy,
                     probe=ctaugment.apply(x, policy),
                     image=x)
     assert not probe

     def cutout_policy():
         # Called once per augmented entry, so each gets its own cta.policy result.
         return cta.policy(probe=False) + [ctaugment.OP('cutout', (1,))]

     # x[0] is handled before x[1:] so policies are requested in the same order
     # as before.
     if anchoring[0] == 'weak':
         aug0 = [x[0]]
     else:
         aug0 = [ctaugment.apply(x[0], cutout_policy())]
     if anchoring[1] == 'weak':
         aug1 = list(x[1:])
     else:
         aug1 = [ctaugment.apply(y, cutout_policy()) for y in x[1:]]
     return dict(image=np.stack(aug0 + aug1).astype('f'))