示例#1
0
class Item(BaseSprite, metaclass=abc.ABCMeta):
    """Abstract base class for sprites that represent collectible in-game items.

    An item bobs up and down around the point it was spawned at. When a sprite
    picks it up, the subclass-defined effect is applied, the item's sound is
    played, and the item sprite is killed so it is no longer drawn.
    """
    # Number of pixels up and down that item will bob.
    BOB_RANGE = 15
    # Amount the bobbing step counter advances on every update.
    BOB_SPEED = 0.2

    def __init__(self, x: float, y: float, image: str, sound: str,
                 groups: typing.Dict[str, pg.sprite.Group]):
        """Create the item centered at (x, y) and add it to the 'all' and 'items' groups.

        :param x: x coordinate of the item's center (and bobbing anchor).
        :param y: y coordinate of the item's center (and bobbing anchor).
        :param image: filename of the item's image.
        :param sound: name of the sound effect played on pickup.
        :param groups: dictionary of the game world's sprite groups.
        """
        BaseSprite.__init__(self, image, groups, groups['all'], groups['items'])
        self.rect.center = (x, y)
        self._spawn_pos = pg.math.Vector2(x, y)
        self._sfx = sound
        self._effect_timer = Timer()
        # Effect duration compared against the effect timer; 0 unless a subclass changes it.
        self._duration = 0
        # Easing curve applied to the fraction (step / BOB_RANGE) of the current sweep.
        self._tween = tween.easeInOutSine
        self._step = 0
        # +1 or -1: which way the current sweep moves relative to the anchor.
        self._direction = 1

    @property
    def spawn_pos(self) -> pg.math.Vector2:
        """Position the item was created at; the center of its bobbing motion."""
        return self._spawn_pos

    def update(self, dt: float) -> None:
        """Advance the floating animation by one frame.

        Credits to Chris Bradfield from KidsCanCode.

        :param dt: time since the last update (not used; the animation advances per call).
        """
        fraction = self._step / Item.BOB_RANGE
        # Shifting the eased value by 0.5 centers the motion on the spawn point.
        bob = Item.BOB_RANGE * (self._tween(fraction) - 0.5)
        self.rect.centery = self._spawn_pos.y + bob * self._direction
        self._step += Item.BOB_SPEED
        # Once a sweep completes, restart it in the opposite direction.
        if self._step > Item.BOB_RANGE:
            self._step, self._direction = 0, -self._direction

    def activate(self, sprite: pg.sprite.Sprite) -> None:
        """Apply the item's effect to ``sprite`` and stop drawing the item.

        The effect timer restarts at pickup, so ``effect_subsided`` measures
        time from this moment.
        """
        self._apply_effect(sprite)
        self._effect_timer.restart()
        sfx_loader.play(self._sfx)
        # Only the sprite is removed; a timed effect keeps running until it subsides.
        super().kill()

    def effect_subsided(self) -> bool:
        """Return True once the time since pickup exceeds the effect's duration."""
        elapsed = self._effect_timer.elapsed()
        return elapsed > self._duration

    @abc.abstractmethod
    def _apply_effect(self, sprite) -> None:
        """Apply this item's effect to ``sprite``; subclasses must implement it."""

    def remove_effect(self, sprite) -> None:
        """Undo the item's effect on ``sprite`` when it subsides; does nothing by default.

        Items whose effect has a non-zero duration override this.
        """
示例#2
0
class Tank(BaseSprite, MoveNonlinearMixin, RotateMixin, DamageMixin):
    """Sprite class that models a Tank object.

    A tank owns a list of barrels (which aim, fire and reload) and a list of
    picked-up items whose effects are still active on it.
    """
    KNOCK_BACK = 100

    # Threshold compared against ``vel.length_squared()`` (the squared speed)
    # before track sprites are spawned.
    _SPEED_CUTOFF = 100
    # Minimum ``Timer.elapsed()`` between two track sprites (presumably ms).
    _TRACK_DELAY = 100

    # Size identifiers accepted by ``Tank.enemy``.
    BIG = "big"
    LARGE = "large"
    HUGE = "huge"

    def __init__(self, x: float, y: float, img: str,
                 all_groups: typing.Dict[str, pg.sprite.Group]):
        """Initializes the tank's sprite with no barrels to shoot from.

        :param x: x coordinate for centering the sprite's position.
        :param y: y coordinate for centering the sprite's position.
        :param img: filename for the sprite's tank image.
        :param all_groups: A dictionary of all of the game world's sprite groups.
        """
        # The layer must be set before the sprite is added to the layered 'all' group.
        self._layer = cfg.TANK_LAYER
        BaseSprite.__init__(self, img, all_groups, all_groups['all'],
                            all_groups['tanks'], all_groups['damageable'])
        MoveNonlinearMixin.__init__(self, x, y)
        RotateMixin.__init__(self)
        DamageMixin.__init__(self, self.hit_rect)
        self.rect.center = (x, y)
        self.MAX_ACCELERATION = 768
        self._barrels = []
        self._items = []
        self._track_timer = Timer()

    def update(self, dt: float) -> None:
        """Rotates, moves, expires subsided item effects, and leaves tracks while moving.

        :param dt: Time elapsed since the tank's last update.
        :return: None
        """
        self.rotate(dt)
        self.move(dt)
        # Build the surviving list instead of calling list.remove() while
        # iterating, which would skip the item after each removed one.
        still_active = []
        for item in self._items:
            if item.effect_subsided():
                item.remove_effect(self)
            else:
                still_active.append(item)
        self._items[:] = still_active
        is_moving = self.vel.length_squared() > Tank._SPEED_CUTOFF
        if is_moving and self._track_timer.elapsed() > Tank._TRACK_DELAY:
            self._spawn_tracks()

    @property
    def range(self) -> float:
        """The shooting distance of the tank, as given by its first barrel."""
        return self._barrels[0].range

    @property
    def color(self) -> str:
        """Returns a string representing the color of the tank's first barrel."""
        return self._barrels[0].color

    def pickup(self, item) -> None:
        """Activates an item that this Tank object has picked up (collided with) and saves it.

        :param item: Item sprite that can be used to apply an effect on the Tank object.
        :return: None
        """
        item.activate(self)
        # Kept so the effect can be removed in update() once it subsides.
        self._items.append(item)

    def equip_barrel(self, barrel: Barrel) -> None:
        """Equips a new barrel to this tank."""
        self._barrels.append(barrel)

    def _spawn_tracks(self) -> None:
        """Spawns a track sprite at the tank's position and restarts the track timer."""
        Tracks(*self.pos, self.hit_rect.height, self.hit_rect.height, self.rot,
               self.all_groups)
        self._track_timer.restart()

    def rotate_barrel(self, aim_direction: float):
        """Rotates all of the tank's barrels to the angle given by aim_direction."""
        for barrel in self._barrels:
            barrel.rot = aim_direction
            barrel.rotate()

    def ammo_count(self) -> int:
        """Returns the ammo count of the tank's first barrel."""
        return self._barrels[0].ammo_count

    def fire(self) -> None:
        """Asks every barrel to fire a bullet."""
        for barrel in self._barrels:
            barrel.fire()

    def reload(self) -> None:
        """Reloads the bullets of each of the tank's barrels."""
        for barrel in self._barrels:
            barrel.reload()

    def kill(self) -> None:
        """Removes this sprite, its barrels and its held items from all sprite groups."""
        for barrel in self._barrels:
            barrel.kill()
        for item in self._items:
            item.kill()
        super().kill()

    @classmethod
    def color_tank(cls, x: float, y: float, color: str, category: str,
                   groups: typing.Dict[str, pg.sprite.Group]):
        """Factory method for creating a colored Tank with one matching barrel.

        :param color: color name used in the body image filename and for the barrel.
        :param category: barrel/bullet category passed on to the barrel.
        """
        tank = cls(x, y, f"tankBody_{color}_outline.png", groups)
        offset = pg.math.Vector2(tank.hit_rect.height // 3, 0)
        barrel = Barrel.create_color_barrel(tank, offset, color.capitalize(),
                                            category, groups)
        tank.equip_barrel(barrel)
        return tank

    @classmethod
    def enemy(cls, x: float, y: float, size: str,
              groups: typing.Dict[str, pg.sprite.Group]) -> 'Tank':
        """Returns an enemy tank built according to the size parameter.

        :raises ValueError: if size is not one of BIG, LARGE or HUGE.
        """
        if size == cls.BIG:
            return cls.big_tank(x, y, groups)
        elif size == cls.LARGE:
            return cls.large_tank(x, y, groups)
        elif size == cls.HUGE:
            return cls.huge_tank(x, y, groups)
        raise ValueError(f"Invalid size attribute: {size}")

    @classmethod
    def big_tank(cls, x: float, y: float,
                 groups: typing.Dict[str, pg.sprite.Group]) -> 'Tank':
        """Returns a 'big' enemy tank with two special barrels."""
        tank = cls(x, y, "tankBody_bigRed.png", groups)
        for y_offset in (-10, 10):
            barrel = Barrel.create_special(tank,
                                           pg.math.Vector2(0, y_offset),
                                           "Dark",
                                           groups,
                                           special=1)
            tank.equip_barrel(barrel)
        return tank

    @classmethod
    def large_tank(cls, x: float, y: float,
                   groups: typing.Dict[str, pg.sprite.Group]) -> 'Tank':
        """Returns a 'large' enemy tank: slower acceleration, two special barrels."""
        tank = cls(x, y, "tankBody_darkLarge.png", groups)
        tank.MAX_ACCELERATION *= 0.9
        for y_offset in (-10, 10):
            barrel = Barrel.create_special(tank,
                                           pg.math.Vector2(0, y_offset),
                                           "Dark",
                                           groups,
                                           special=4)
            tank.equip_barrel(barrel)
        return tank

    @classmethod
    def huge_tank(cls, x: float, y: float,
                  groups: typing.Dict[str, pg.sprite.Group]) -> 'Tank':
        """Returns a 'huge' enemy tank: slowest acceleration, three special barrels."""
        tank = cls(x, y, "tankBody_huge_outline.png", groups)
        tank.MAX_ACCELERATION *= 0.8
        for y_offset in (-10, 10):
            barrel = Barrel.create_special(tank,
                                           pg.math.Vector2(20, y_offset),
                                           "Dark",
                                           groups,
                                           special=4)
            tank.equip_barrel(barrel)
        # A third barrel sits behind the front pair.
        barrel = Barrel.create_special(tank,
                                       pg.math.Vector2(-10, 0),
                                       "Dark",
                                       groups,
                                       special=1)
        tank.equip_barrel(barrel)
        return tank
示例#3
0
def train(cfg):
    """Train the model produced by ``build_model`` using settings from ``cfg``.

    Builds the training reader and program. It then either resumes from a
    checkpoint, loads shape-compatible pretrained parameters, or starts from
    scratch. Next it runs epochs from the start epoch through
    ``cfg.SOLVER.NUM_EPOCHS`` inclusive. Every ``cfg.TRAIN.SNAPSHOT_EPOCH``
    epochs, trainer 0 saves a checkpoint and optionally evaluates and
    visualizes it. A 'final' checkpoint is saved at the end.

    Command-line options are read from the module-level ``args`` (use_gpu,
    use_mpio, debug, use_vdl, vdl_log_dir, log_steps, do_eval).

    :param cfg: configuration object; ``cfg.TRAIN_BATCH_SIZE`` is overwritten
        with ``device count * cfg.TRAIN_BATCH_SIZE_PER_GPU``.
    :raises ValueError: if the start epoch is larger than ``cfg.SOLVER.NUM_EPOCHS``.
    """
    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    drop_last = True
    dataset = build_dataset(
        cfg.DATASET.DATASET_NAME,
        file_list=cfg.DATASET.TRAIN_FILE_LIST,
        mode=ModelPhase.TRAIN,
        shuffle=True,
        data_dir=cfg.DATASET.DATA_DIR,
        base_size=cfg.DATAAUG.BASE_SIZE,
        crop_size=cfg.DATAAUG.CROP_SIZE,
        rand_scale=True)

    def data_generator():
        """Yield (item[0], item[1], item[2]) samples in per-trainer batch-sized chunks."""
        if args.use_mpio:
            data_gen = dataset.multiprocess_generator(
                num_processes=cfg.DATALOADER.NUM_WORKERS,
                max_queue_size=cfg.DATALOADER.BUF_SIZE)
        else:
            data_gen = dataset.generator()

        # cfg.TRAIN_BATCH_SIZE is read lazily here, after it has been set below.
        batch_data = []
        for b in data_gen:
            batch_data.append(b)
            if len(batch_data) == (cfg.TRAIN_BATCH_SIZE // cfg.NUM_TRAINERS):
                for item in batch_data:
                    yield item[0], item[1], item[2]
                batch_data = []
        # If the sync batch norm strategy is used, drop the last batch when the
        # number of samples in batch_data is less than cfg.BATCH_SIZE, to avoid
        # NCCL hang issues.
        if not cfg.TRAIN.SYNC_BATCH_NORM:
            for item in batch_data:
                yield item[0], item[1], item[2]

    # Get device environment
    gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
    place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
    places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()

    # Get number of GPU
    dev_count = cfg.NUM_TRAINERS if cfg.NUM_TRAINERS > 1 else len(places)
    print_info("#device count: {}".format(dev_count))
    cfg.TRAIN_BATCH_SIZE = dev_count * int(cfg.TRAIN_BATCH_SIZE_PER_GPU)
    print_info("#train_batch_size: {}".format(cfg.TRAIN_BATCH_SIZE))
    print_info("#batch_size_per_dev: {}".format(cfg.TRAIN_BATCH_SIZE_PER_GPU))

    py_reader, avg_loss, lr, pred, grts, masks = build_model(
        train_prog, startup_prog, phase=ModelPhase.TRAIN)
    py_reader.decorate_sample_generator(
        data_generator, batch_size=cfg.TRAIN_BATCH_SIZE_PER_GPU, drop_last=drop_last)

    exe = fluid.Executor(place)
    exe.run(startup_prog)

    exec_strategy = fluid.ExecutionStrategy()
    if args.use_gpu:
        exec_strategy.num_threads = fluid.core.get_cuda_device_count()
    # Clear temporary variables every 100 iterations
    exec_strategy.num_iteration_per_drop_scope = 100
    build_strategy = fluid.BuildStrategy()

    if cfg.NUM_TRAINERS > 1 and args.use_gpu:
        dist_utils.prepare_for_multi_process(exe, build_strategy, train_prog)
        exec_strategy.num_threads = 1

    if cfg.TRAIN.SYNC_BATCH_NORM and args.use_gpu:
        if dev_count > 1:
            # Apply sync batch norm strategy
            print_info("Sync BatchNorm strategy is effective.")
            build_strategy.sync_batch_norm = True
        else:
            print_info(
                "Sync BatchNorm strategy will not be effective if GPU device"
                " count <= 1")
    compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
        loss_name=avg_loss.name,
        exec_strategy=exec_strategy,
        build_strategy=build_strategy)

    # Resume training
    begin_epoch = cfg.SOLVER.BEGIN_EPOCH
    if cfg.TRAIN.RESUME_MODEL_DIR:
        begin_epoch = load_checkpoint(exe, train_prog)
    # Load pretrained model
    elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL_DIR):
        print_info('Pretrained model dir: ', cfg.TRAIN.PRETRAINED_MODEL_DIR)
        load_vars = []
        load_fail_vars = []

        def var_shape_matched(var, shape):
            """
            Check whether a persistable variable's saved shape matches the current network
            """
            var_exist = os.path.exists(
                os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
            if var_exist:
                var_shape = parse_shape_from_file(
                    os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
                return var_shape == shape
            return False

        # Only parameters whose saved shape matches the freshly initialized one are loaded.
        for x in train_prog.list_vars():
            if isinstance(x, fluid.framework.Parameter):
                shape = tuple(fluid.global_scope().find_var(
                    x.name).get_tensor().shape())
                if var_shape_matched(x, shape):
                    load_vars.append(x)
                else:
                    load_fail_vars.append(x)

        fluid.io.load_vars(
            exe, dirname=cfg.TRAIN.PRETRAINED_MODEL_DIR, vars=load_vars)
        for var in load_vars:
            print_info("Parameter[{}] loaded sucessfully!".format(var.name))
        for var in load_fail_vars:
            print_info(
                "Parameter[{}] don't exist or shape does not match current network, skip"
                " to load it.".format(var.name))
        print_info("{}/{} pretrained parameters loaded successfully!".format(
            len(load_vars),
            len(load_vars) + len(load_fail_vars)))
    else:
        print_info(
            'Pretrained model dir {} not exists, training from scratch...'.
            format(cfg.TRAIN.PRETRAINED_MODEL_DIR))

    fetch_list = [avg_loss.name, lr.name]
    if args.debug:
        # Fetch more variable info and use a streaming confusion matrix to
        # calculate IoU results in debug mode
        np.set_printoptions(
            precision=4, suppress=True, linewidth=160, floatmode="fixed")
        fetch_list.extend([pred.name, grts.name, masks.name])
        cm = ConfusionMatrix(cfg.DATASET.NUM_CLASSES, streaming=True)

    if args.use_vdl:
        if not args.vdl_log_dir:
            print_info("Please specify the log directory by --vdl_log_dir.")
            sys.exit(1)

        from visualdl import LogWriter
        log_writer = LogWriter(args.vdl_log_dir)

    # trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
    # num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
    step = 0
    all_step = cfg.DATASET.TRAIN_TOTAL_IMAGES // cfg.TRAIN_BATCH_SIZE
    if cfg.DATASET.TRAIN_TOTAL_IMAGES % cfg.TRAIN_BATCH_SIZE and not drop_last:
        all_step += 1
    all_step *= (cfg.SOLVER.NUM_EPOCHS - begin_epoch + 1)

    # From here on avg_loss is a running float; the loss variable's name was
    # already captured in fetch_list and compiled_train_prog above.
    avg_loss = 0.0
    timer = Timer()
    timer.start()
    if begin_epoch > cfg.SOLVER.NUM_EPOCHS:
        raise ValueError(
            ("begin epoch[{}] is larger than cfg.SOLVER.NUM_EPOCHS[{}]").format(
                begin_epoch, cfg.SOLVER.NUM_EPOCHS))

    if args.use_mpio:
        print_info("Use multiprocess reader")
    else:
        print_info("Use multi-thread reader")

    for epoch in range(begin_epoch, cfg.SOLVER.NUM_EPOCHS + 1):
        py_reader.start()
        while True:
            try:
                if args.debug:
                    # Print category IoU and accuracy to check whether the
                    # training process matches expectations
                    loss, lr, pred, grts, masks = exe.run(
                        program=compiled_train_prog,
                        fetch_list=fetch_list,
                        return_numpy=True)
                    cm.calculate(pred, grts, masks)
                    avg_loss += np.mean(np.array(loss))
                    step += 1

                    if step % args.log_steps == 0:
                        speed = args.log_steps / timer.elapsed_time()
                        avg_loss /= args.log_steps
                        category_acc, mean_acc = cm.accuracy()
                        category_iou, mean_iou = cm.mean_iou()

                        print_info((
                            "epoch={}/{} step={}/{} lr={:.5f} loss={:.4f} acc={:.5f} mIoU={:.5f} step/sec={:.3f} | ETA {}"
                        ).format(epoch, cfg.SOLVER.NUM_EPOCHS, step, all_step, lr[0], avg_loss, mean_acc,
                                 mean_iou, speed,
                                 calculate_eta(all_step - step, speed)))
                        print_info("Category IoU: ", category_iou)
                        print_info("Category Acc: ", category_acc)
                        if args.use_vdl:
                            log_writer.add_scalar('Train/mean_iou', mean_iou,
                                                  step)
                            log_writer.add_scalar('Train/mean_acc', mean_acc,
                                                  step)
                            log_writer.add_scalar('Train/loss', avg_loss,
                                                  step)
                            log_writer.add_scalar('Train/lr', lr[0],
                                                  step)
                            log_writer.add_scalar('Train/step/sec', speed,
                                                  step)
                        sys.stdout.flush()
                        avg_loss = 0.0
                        cm.zero_matrix()
                        timer.restart()
                else:
                    # If not in debug mode, avoid unnecessary logging and computation
                    loss, lr = exe.run(
                        program=compiled_train_prog,
                        fetch_list=fetch_list,
                        return_numpy=True)
                    avg_loss += np.mean(np.array(loss))
                    step += 1

                    if step % args.log_steps == 0 and cfg.TRAINER_ID == 0:
                        avg_loss /= args.log_steps
                        speed = args.log_steps / timer.elapsed_time()
                        # Use the local step counter; the previously referenced
                        # `global_step` was undefined and raised NameError here.
                        print((
                            "epoch={}/{} step={}/{} lr={:.5f} loss={:.4f} step/sec={:.3f} | ETA {}"
                        ).format(epoch, cfg.SOLVER.NUM_EPOCHS, step, all_step, lr[0], avg_loss, speed,
                                 calculate_eta(all_step - step, speed)))
                        if args.use_vdl:
                            log_writer.add_scalar('Train/loss', avg_loss,
                                                  step)
                            log_writer.add_scalar('Train/lr', lr[0],
                                                  step)
                            log_writer.add_scalar('Train/speed', speed,
                                                  step)
                        sys.stdout.flush()
                        avg_loss = 0.0
                        timer.restart()

            except fluid.core.EOFException:
                # The reader is exhausted: reset it and finish the epoch.
                py_reader.reset()
                break
            except Exception as e:
                # NOTE(review): this best-effort catch keeps training after a
                # failed step; a persistent failure would loop forever here.
                print(e)

        if epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0 and cfg.TRAINER_ID == 0:
            ckpt_dir = save_checkpoint(exe, train_prog, epoch)

            if args.do_eval:
                print("Evaluation start")
                _, mean_iou, _, mean_acc = evaluate(
                    cfg=cfg,
                    ckpt_dir=ckpt_dir,
                    use_gpu=args.use_gpu,
                    use_mpio=args.use_mpio)
                if args.use_vdl:
                    log_writer.add_scalar('Evaluate/mean_iou', mean_iou,
                                          step)
                    log_writer.add_scalar('Evaluate/mean_acc', mean_acc,
                                          step)

            # Use VisualDL to visualize results
            if args.use_vdl and cfg.DATASET.VIS_FILE_LIST is not None:
                visualize(
                    cfg=cfg,
                    use_gpu=args.use_gpu,
                    vis_file_list=cfg.DATASET.VIS_FILE_LIST,
                    vis_dir="visual",
                    ckpt_dir=ckpt_dir,
                    log_writer=log_writer)

    # save final model
    if cfg.TRAINER_ID == 0:
        save_checkpoint(exe, train_prog, 'final')

    if args.use_vdl:
        log_writer.close()
示例#4
0
class Level:
    """Class that creates, draws, and updates the game world, including the map and all sprites."""
    # Compared with Timer.elapsed(); 30000 is 30 seconds assuming millisecond timers.
    _ITEM_RESPAWN_TIME = 30000

    def __init__(self, level_file: str):
        """Creates a map and creates all of the sprites in it.

        :param level_file: Filename of level file to load from the configuration file's map folder.
        """
        # Create the tiled map surface.
        map_loader = TiledMapLoader(level_file)
        self.image = map_loader.make_map()
        self.rect = self.image.get_rect()
        self._groups = {
            'all': pg.sprite.LayeredUpdates(),
            'tanks': pg.sprite.Group(),
            'damageable': pg.sprite.Group(),
            'bullets': pg.sprite.Group(),
            'obstacles': pg.sprite.Group(),
            'items': pg.sprite.Group(),
            'item_boxes': pg.sprite.Group()
        }
        self._player = None
        self._camera = None
        self._ai_mobs = []
        self._item_spawn_positions = []
        self._item_spawn_timer = Timer()
        # Initialize all sprites in game world.
        self._init_sprites(map_loader.tiled_map.objects)

    def _init_sprites(self, objects: pytmx.TiledObjectGroup) -> None:
        """Initializes all of the pygame sprites in this level's map.

        :param objects: Iterator for accessing the properties of all game objects to be created.
        :return: None

        Expects to find a single 'player' and 'enemy_tank' object, and possibly more than one
        of any other object. Turrets, trees and item-box spawns are optional. A sprite is
        created out of each object and added to the appropriate group. A boundary for the
        game world is also created to keep the sprites constrained.
        """
        game_objects = {}
        for t_obj in objects:
            # Expect single enemy tank and player; collect lists of other objects.
            if t_obj.name == 'enemy_tank' or t_obj.name == "player":
                game_objects[t_obj.name] = t_obj
            else:
                game_objects.setdefault(t_obj.name, []).append(t_obj)

        # Create the player and world camera.
        p = game_objects.get('player')
        tank = Tank.color_tank(p.x, p.y, p.color, p.category, self._groups)
        self._player = PlayerCtrl(tank)
        self._camera = Camera(self.rect.width, self.rect.height,
                              self._player.tank)

        # Spawn single enemy tank.
        t = game_objects.get('enemy_tank')
        tank = Tank.enemy(t.x, t.y, t.size, self._groups)
        ai_patrol_points = game_objects.get('ai_patrol_point')
        ai_boss = AITankCtrl(tank, ai_patrol_points, self._player.tank)
        self._ai_mobs.append(ai_boss)

        # Spawn turrets (a map may have none).
        for t in game_objects.get('turret', []):
            turret = Turret(t.x, t.y, t.category, t.special, self._groups)
            self._ai_mobs.append(
                AITurretCtrl(turret, ai_boss, self._player.tank))

        # Spawn obstacles that one can collide with.
        for tree in game_objects.get('small_tree', []):
            Tree(tree.x, tree.y, self._groups)

        # Spawn items boxes that can be destroyed to get an item.
        for box in game_objects.get('box_spawn', []):
            self._item_spawn_positions.append((box.x, box.y))
            ItemBox.spawn(box.x, box.y, self._groups)

        # Creates the boundaries of the game world.
        BoundaryWall(x=0,
                     y=0,
                     width=self.rect.width,
                     height=1,
                     all_groups=self._groups)  # Top
        BoundaryWall(x=0,
                     y=self.rect.height,
                     width=self.rect.width,
                     height=1,
                     all_groups=self._groups)  # Bottom
        BoundaryWall(x=0,
                     y=0,
                     width=1,
                     height=self.rect.height,
                     all_groups=self._groups)  # Left
        BoundaryWall(x=self.rect.width,
                     y=0,
                     width=1,
                     height=self.rect.height,
                     all_groups=self._groups)  # Right

    def _can_spawn_item(self) -> bool:
        """Checks if a new item box can be spawned.

        True when the respawn delay has passed and fewer items plus item boxes
        exist than there are spawn positions.
        """
        return self._item_spawn_timer.elapsed() > Level._ITEM_RESPAWN_TIME and \
            len(self._groups['items']) + len(self._groups['item_boxes']) < len(self._item_spawn_positions)

    def is_player_alive(self) -> bool:
        """Checks if the player's tank is still alive."""
        return self._player.tank.alive()

    def mob_count(self) -> int:
        """Returns the number of AI mobs that have not been defeated."""
        return len(self._ai_mobs)

    def process_inputs(self) -> None:
        """Handles keys and clicks that affect the game world."""
        self._player.handle_keys()
        # Convert mouse coordinates to world coordinates.
        mouse_x, mouse_y = pg.mouse.get_pos()
        mouse_world_pos = pg.math.Vector2(mouse_x + self._camera.rect.x,
                                          mouse_y + self._camera.rect.y)
        self._player.handle_mouse(mouse_world_pos)

    def update(self, dt: float) -> None:
        """Updates the game world's AI, sprites, camera, and resolves collisions.

        :param dt: time elapsed since the last update of the game world.
        :return: None
        """
        for ai in self._ai_mobs:
            ai.update(dt)
        self._groups['all'].update(dt)
        self._camera.update()

        game_items_count = len(self._groups['items'])
        collision_handler.handle_collisions(self._groups)
        # An item was picked up during collision handling: restart the respawn delay.
        if game_items_count > 0 and len(
                self._groups['items']) < game_items_count:
            self._item_spawn_timer.restart()
        # See if it's time to spawn a new item.
        if self._can_spawn_item():
            # Positions still holding an uncollected item or an item box are taken.
            # Compare both Vector2 components (previously y was compared to the
            # whole vector, so item positions were never excluded).
            occupied = {(sprite.spawn_pos.x, sprite.spawn_pos.y)
                        for sprite in self._groups['items']}
            occupied.update(sprite.rect.center
                            for sprite in self._groups['item_boxes'])
            available_positions = [pos for pos in self._item_spawn_positions
                                   if pos not in occupied]
            if available_positions:
                x, y = random.choice(available_positions)
                ItemBox.spawn(x, y, self._groups)

        # Filter out any AIs that have been defeated.
        self._ai_mobs = [ai for ai in self._ai_mobs if ai.sprite.alive()]

    def draw(self, screen: pg.Surface) -> None:
        """Draws every sprite in the game world, as well as heads-up display elements.

        :param screen: The screen surface that the world's elements will be drawn to.
        :return: None
        """
        # Draw the map.
        screen.blit(self.image, self._camera.apply(self.rect))
        # Draw all sprites, offset by the camera.
        for sprite in self._groups['all']:
            screen.blit(sprite.image, self._camera.apply(sprite.rect))

        # Draw HUD.
        for ai in self._ai_mobs:
            ai.sprite.draw_health(screen, self._camera)
        self._player.draw_hud(screen, self._camera)
示例#5
0
class Barrel(BaseSprite, RotateMixin):
    """Sprite for a tank barrel that follows its parent tank and fires bullets."""
    _FIRE_SFX = 'shoot.wav'

    def __init__(self, tank, offset: pg.math.Vector2, image: str, color: str,
                 category: str, all_groups: typing.Dict[str, pg.sprite.Group]):
        """Create a fully loaded barrel centered on ``tank``.

        :param tank: parent sprite the barrel is attached to.
        :param offset: displacement from the parent's center, rotated with the barrel.
        :param image: filename of the barrel image.
        :param color: color of the bullets this barrel fires.
        :param category: key into ``_STATS`` selecting ammo and fire-delay values.
        :param all_groups: dictionary of the game world's sprite groups.
        """
        # Must be set before the sprite joins the layered 'all' group.
        self._layer = cfg.BARREL_LAYER
        BaseSprite.__init__(self, image, all_groups, all_groups['all'])
        RotateMixin.__init__(self)
        stats = _STATS[category]

        # What the barrel shoots.
        self._category = category
        self._color = color
        self._ammo_count = stats["max_ammo"]

        # Where the barrel sits relative to its parent.
        self._parent = tank
        self.rect.center = tank.rect.center
        self._offset = offset

        # How often the barrel may shoot.
        self._fire_delay = stats["fire_delay"]
        self._fire_timer = Timer()

    @property
    def color(self) -> str:
        """Color name of the bullets this barrel fires."""
        return self._color

    @property
    def ammo_count(self) -> int:
        """Number of shots left before a reload is needed."""
        return self._ammo_count

    @property
    def range(self) -> float:
        """Firing range for this barrel's bullet category."""
        return Bullet.range(self._category)

    @property
    def fire_delay(self) -> float:
        """Minimum number of milliseconds between two shots."""
        return self._fire_delay

    def update(self, dt: float) -> None:
        """Re-center the barrel on its parent, applying the offset rotated by the barrel's angle."""
        shift = self._offset.rotate(-self.rot)
        anchor = self._parent.rect
        self.rect.centerx = anchor.centerx + shift.x
        self.rect.centery = anchor.centery + shift.y

    def fire(self) -> None:
        """Shoot once, but only with ammo left and after the fire delay has passed."""
        ready = self._ammo_count > 0 and self._fire_timer.elapsed() > self._fire_delay
        if not ready:
            return
        self._spawn_bullet()
        sfx_loader.play(Barrel._FIRE_SFX)
        self._fire_timer.restart()

    def _spawn_bullet(self) -> None:
        """Create a bullet and muzzle flash at the barrel's tip and use up one round."""
        nozzle = pg.math.Vector2(self.hit_rect.height, 0).rotate(-self.rot)
        nozzle.xy += self.rect.center
        Bullet(nozzle.x, nozzle.y, self.rot, self._color, self._category, self._parent, self.all_groups)
        MuzzleFlash(nozzle.x, nozzle.y, self.rot, self.all_groups)
        self._ammo_count -= 1

    def reload(self) -> None:
        """Refill the barrel to the maximum ammo for its category."""
        full = _STATS[self._category]["max_ammo"]
        self._ammo_count = full

    def kill(self) -> None:
        """Drop the reference to the parent tank and remove the barrel sprite."""
        self._parent = None
        super().kill()

    @classmethod
    def create_color_barrel(cls, tank, offset: pg.math.Vector2, color: str, category: str,
                            groups: typing.Dict[str, pg.sprite.Group]) -> 'Barrel':
        """Build a barrel whose image is chosen by color and category."""
        filename = f"tank{color.capitalize()}_barrel{cfg.CATEGORY[category]}.png"
        return cls(tank, offset, filename, color, category, groups)

    @classmethod
    def create_special(cls, tank, offset: pg.math.Vector2, color: str,
                       all_groups: typing.Dict[str, pg.sprite.Group], special) -> 'Barrel':
        """Build a 'standard'-category barrel using special barrel image number ``special``, mirrored horizontally."""
        filename = f"specialBarrel{special}.png"
        special_barrel = cls(tank, offset, filename, color, "standard", all_groups)
        helpers.flip(special_barrel, orig_image=special_barrel.image, x_reflect=True, y_reflect=False)
        return special_barrel