Code Example #1
import collections

import numpy as np
import torch

# SPATIAL_FEATURES (the project's spec of pysc2 spatial features) is defined
# elsewhere in the repository.


def replay_collate_fn(samples, max_timesteps=50, weighting='log'):
    """
    Arguments:
        samples: a list of tuples (x: dict of np.ndarray, y: int).
        max_timesteps: the maximum number of timesteps to keep per sample.
            Frames are sampled without replacement, with probabilities
            proportional to the chosen weighting function of the frame index,
            normalized over the episode.
        weighting: str, one of 'exp' or 'log'.
    """

    if weighting == 'exp':
        weight_fn = np.exp
    elif weighting == 'log':
        weight_fn = lambda x: np.log(1 + x)
    else:
        raise ValueError(f"weighting must be 'exp' or 'log', got {weighting!r}")

    # Clamp to the shortest episode in the batch up front, so that every
    # sample contributes the same number of frames and torch.stack succeeds.
    shortest = min(x['unit_type'].shape[0] for x, _ in samples)
    max_timesteps = min(max_timesteps, shortest)

    batched_inputs = collections.defaultdict(list)
    batched_targets = []
    for input_dict, target in samples:

        # Sample frame indices without replacement, weighted toward later frames.
        timesteps = input_dict['unit_type'].shape[0]
        weights = np.asarray([weight_fn(i) for i in range(timesteps)], dtype=np.float64)
        weights /= weights.sum()
        try:
            timestep_indices = np.random.choice(timesteps, max_timesteps,
                                                replace=False, p=weights)
        except ValueError as e:
            # With 'log' weighting, frame 0 gets zero weight, so sampling
            # without replacement can fail when max_timesteps == timesteps.
            raise ValueError(f"Could not sample {max_timesteps} of "
                             f"{timesteps} frames: {e}") from e
        timestep_indices = sorted(timestep_indices)

        for name, feature in input_dict.items():
            sampled_feature = feature[timestep_indices]
            if name == 'unit_type':
                # Zero out unit ids beyond the declared scale of the feature.
                mask = sampled_feature > SPATIAL_FEATURES._asdict()['unit_type'].scale + 1
                sampled_feature[mask] = 0
            batched_inputs[name].append(torch.from_numpy(sampled_feature))

        batched_targets.append(target)

    batched_tensor_inputs = {
        name: torch.stack(tensors, dim=0)
        for name, tensors in batched_inputs.items()
    }

    out = {
        'inputs': batched_tensor_inputs,
        'targets': torch.FloatTensor(batched_targets),
    }

    return out
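To see the sampling scheme used by replay_collate_fn in isolation, here is a minimal, self-contained sketch; all sizes are made up for illustration:

import numpy as np

timesteps, max_timesteps = 8, 4  # hypothetical sizes

# Log weighting: later frames receive larger normalized probabilities.
weights = np.log(1 + np.arange(timesteps, dtype=np.float64))
weights /= weights.sum()

# Frame 0 has zero probability under log weighting, so sampling without
# replacement needs max_timesteps <= number of nonzero weights.
indices = np.sort(np.random.choice(timesteps, max_timesteps, replace=False, p=weights))
print(indices)  # e.g. [2 5 6 7] -- biased toward late frames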
Code Example #2
File: resnet.py Project: hgkahng/pysc2-defogging
import torch.nn as nn

# CategoricalEmbedding, ScalarEmbedding, BasicBlock3D, and SPATIAL_FEATURES
# are defined elsewhere in the project.


class ResNet3D(nn.Module):

    def __init__(self,
                 embedding_dims=None,
                 num_classes=2,
                 include=('unit_type', 'player_relative')):  # tuple avoids a mutable default
        super(ResNet3D, self).__init__()

        if embedding_dims is None:
            self.embedding_dims = {
                'height_map': 10,
                'visibility_map': 10,
                'player_relative': 10,
                'unit_type': 100
            }
        else:
            assert isinstance(embedding_dims, dict)
            self.embedding_dims = embedding_dims

        self.num_classes = num_classes
        self.include = include
        self.cnn_channel_size = 0
        """Embedding layers."""
        self.embeddings = nn.ModuleDict()
        for name, feat in SPATIAL_FEATURES._asdict().items():
            if name not in self.include:
                continue
            feat_type = str(feat.type).split('.')[-1]
            if feat_type == 'CATEGORICAL':
                self.embeddings[name] = CategoricalEmbedding(
                    category_size=feat.scale,
                    embedding_dim=self.embedding_dims.get(name),
                    name=name,
                )
            elif feat_type == 'SCALAR':
                self.embeddings[name] = ScalarEmbedding(
                    embedding_dim=self.embedding_dims.get(name), name=name)
            else:
                raise NotImplementedError
            self.cnn_channel_size += self.embedding_dims.get(name)

        self.conv1 = nn.Conv3d(self.cnn_channel_size,
                               32,
                               7,
                               stride=(1, 2, 2),
                               padding=(3, 3, 3),
                               bias=False)
        self.bn1 = nn.BatchNorm3d(32)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)

        self.block1 = BasicBlock3D(32, 64, stride=1, downsample=None)
        self.pool1 = nn.MaxPool3d(kernel_size=(5, 4, 4), padding=0)
        self.block2 = BasicBlock3D(64, 128, stride=2, downsample=None)
        self.pool2 = nn.MaxPool3d(kernel_size=(1, 2, 2), padding=0)
        self.block3 = BasicBlock3D(128, 256, stride=2, downsample=None)
        self.pool3 = nn.AvgPool3d(kernel_size=(5, 4, 4), stride=1, padding=0)

        self.linear = nn.Linear(256, self.num_classes)
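The constructor's core pattern, embedding each categorical map per cell and concatenating the embeddings along a channel axis for 3D convolution, can be sketched with plain PyTorch. Here nn.Embedding stands in for the project's CategoricalEmbedding, and all sizes are hypothetical:

import torch
import torch.nn as nn

# Fake categorical feature maps: (batch, time, H, W) integer ids.
unit_type = torch.randint(0, 50, (2, 10, 64, 64))
player_relative = torch.randint(0, 5, (2, 10, 64, 64))

emb_unit = nn.Embedding(50, 100)  # category_size -> embedding_dim
emb_player = nn.Embedding(5, 10)

# Embed each cell, then move the embedding dim into the channel position
# expected by nn.Conv3d: (batch, channels, time, H, W).
x = torch.cat([emb_unit(unit_type), emb_player(player_relative)], dim=-1)
x = x.permute(0, 4, 1, 2, 3)  # -> (2, 110, 10, 64, 64)

conv1 = nn.Conv3d(110, 32, 7, stride=(1, 2, 2), padding=(3, 3, 3), bias=False)
print(conv1(x).shape)  # torch.Size([2, 32, 10, 32, 32])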
Code Example #3
import functools

from torch.utils.data import DataLoader

# SC2ReplayDataset, replay_collate_fn (shown in Code Example #1), and
# SPATIAL_FEATURES are defined elsewhere in the project.


if __name__ == '__main__':

    BATCH_SIZE = 2
    WEIGHTING = 'log'
    MAX_TIMESTEPS = 50
    SPATIAL_SPECS = SPATIAL_FEATURES._asdict()
    ROOT = './parsed/TvP/'

    dataset = SC2ReplayDataset(root_dir=ROOT,
                               train=True,
                               max_timesteps=MAX_TIMESTEPS)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=functools.partial(replay_collate_fn,
                                     max_timesteps=MAX_TIMESTEPS,
                                     weighting=WEIGHTING),
    )

    print(f"Number of batches per epoch: {len(dataloader)}")
Code Example #4
import torch.nn as nn

# CategoricalEmbedding, ScalarEmbedding, SimpleConv, SimpleGRU,
# VaswaniAttention, and SPATIAL_FEATURES are defined elsewhere in the project.


class SimpleConvLSTM(nn.Module):

    def __init__(self,
                 embedding_dim,
                 rnn_input_size,
                 rnn_hidden_size,
                 output_size,
                 include=('height_map', 'visibility_map', 'player_relative',
                          'unit_type')):  # tuple avoids a mutable default
        super(SimpleConvLSTM, self).__init__()

        self.embedding_dim = embedding_dim
        self.rnn_input_size = rnn_input_size
        self.rnn_hidden_size = rnn_hidden_size
        self.output_size = output_size
        self.include = include

        self.embedding_dims = {
            'height_map': 10,
            'visibility_map': 10,
            'player_relative': 10,
            'unit_type': 100
        }

        self.cnn_channel_size = 0

        self.feat_names = list(SPATIAL_FEATURES._asdict())  # all spatial feature names

        # Embedding layers.
        self.embeddings = nn.ModuleDict()
        for name, feat in SPATIAL_FEATURES._asdict().items():
            if name not in self.include:
                continue
            feat_type = str(feat.type).split('.')[-1]
            if feat_type == 'CATEGORICAL':
                self.embeddings[name] = CategoricalEmbedding(
                    category_size=feat.scale,
                    embedding_dim=self.embedding_dims.get(name),
                    name=name,
                )
            elif feat_type == 'SCALAR':
                self.embeddings[name] = ScalarEmbedding(
                    embedding_dim=self.embedding_dims.get(name), name=name)
            else:
                raise NotImplementedError
            self.cnn_channel_size += self.embedding_dims.get(name)

        # Convolution module.
        self.conv = SimpleConv(
            in_channels=self.cnn_channel_size,
            output_size=self.rnn_input_size,
        )

        # Recurrent module.
        self.gru = SimpleGRU(
            input_size=self.rnn_input_size,
            hidden_size=self.rnn_hidden_size,
            output_size=self.output_size,
        )

        # Attention module.
        self.attn = VaswaniAttention(
            hidden_size=self.rnn_hidden_size,
            context_size=self.rnn_hidden_size,
        )

        self.linear = nn.Linear(self.rnn_hidden_size, self.output_size)
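The forward pass is not included in this excerpt, but the pooling step that VaswaniAttention presumably performs (attending over per-timestep GRU outputs with a learned query) can be sketched with plain torch ops. DotProductAttention below and all of its sizes are hypothetical stand-ins, not the project's implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DotProductAttention(nn.Module):
    """Hypothetical stand-in: scaled dot-product attention with a learned query."""

    def __init__(self, hidden_size):
        super().__init__()
        self.query = nn.Parameter(torch.randn(hidden_size))
        self.scale = hidden_size ** 0.5

    def forward(self, rnn_outputs):                      # (batch, time, hidden)
        scores = rnn_outputs @ self.query / self.scale   # (batch, time)
        attn = F.softmax(scores, dim=1)
        return (attn.unsqueeze(-1) * rnn_outputs).sum(dim=1)  # (batch, hidden)

rnn_out = torch.randn(2, 50, 128)  # (batch, timesteps, rnn_hidden_size)
pooled = DotProductAttention(128)(rnn_out)
print(pooled.shape)  # torch.Size([2, 128])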