import collections
import functools

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Assumed to be defined elsewhere in this project: SPATIAL_FEATURES (the
# spatial feature spec), SC2ReplayDataset, CategoricalEmbedding,
# ScalarEmbedding, BasicBlock3D, SimpleConv, SimpleGRU, VaswaniAttention.


def replay_collate_fn(samples, max_timesteps=50, weighting='log'):
    """Collate replay samples into a batch, subsampling frames over time.

    Arguments:
        samples: a list of tuples (x: dict, y: int).
        max_timesteps: int, the maximum number of timesteps to keep per
            replay. Frames are sampled without replacement, with normalized
            weights that increase with time.
        weighting: str, one of 'exp' or 'log'.
    """
    if weighting == 'exp':
        weight_fn = np.exp
    elif weighting == 'log':
        weight_fn = lambda x: np.log(1 + x)
    else:
        raise NotImplementedError(f"Unsupported weighting: '{weighting}'")

    # Every sample in the batch must keep the same number of frames for
    # torch.stack below, so clamp to the shortest replay up front. (Mutating
    # max_timesteps inside the loop truncated samples inconsistently,
    # depending on their order within the batch.)
    num_samples = min(max_timesteps,
                      min(x['unit_type'].shape[0] for x, _ in samples))

    batched_inputs = collections.defaultdict(list)
    batched_targets = []
    for input_dict, target in samples:
        timesteps, _, _ = input_dict['unit_type'].shape
        # Weight frame indices starting from 1: after normalization this is
        # identical for 'exp', and it keeps the first frame's 'log' weight
        # nonzero (np.random.choice with replace=False rejects a p vector
        # with fewer nonzero entries than the requested sample size).
        weights = weight_fn(np.arange(1, timesteps + 1))
        weights = weights / np.sum(weights)
        timestep_indices = np.random.choice(
            timesteps, num_samples, replace=False, p=weights)
        timestep_indices = sorted(timestep_indices)

        for name, feature in input_dict.items():
            sampled_feature = feature[timestep_indices]
            if name == 'unit_type':
                # Zero out unit ids beyond the feature's declared scale.
                scale = SPATIAL_FEATURES._asdict()['unit_type'].scale
                sampled_feature[sampled_feature > scale + 1] = 0
            batched_inputs[name].append(torch.from_numpy(sampled_feature))
        batched_targets.append(target)

    batched_tensor_inputs = {
        name: torch.stack(inp, dim=0) for name, inp in batched_inputs.items()
    }
    return {
        'inputs': batched_tensor_inputs,
        'targets': torch.FloatTensor(batched_targets),
    }
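
# Illustration (added for exposition, not part of the original module): how
# the two weighting schemes spread sampling probability across a short
# replay. Both favor later frames; 'exp' concentrates nearly all mass on the
# final few frames, while 'log' stays comparatively flat.
def _demo_frame_weights(timesteps=10):
    for label, fn in [('exp', np.exp), ('log', lambda x: np.log(1 + x))]:
        w = fn(np.arange(1, timesteps + 1))
        print(label, np.round(w / w.sum(), 3))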

class ResNet3D(nn.Module):
    """3D-convolutional ResNet over stacked spatial feature embeddings."""

    def __init__(self, embedding_dims=None, num_classes=2,
                 include=('unit_type', 'player_relative')):
        super(ResNet3D, self).__init__()
        if embedding_dims is None:
            self.embedding_dims = {
                'height_map': 10,
                'visibility_map': 10,
                'player_relative': 10,
                'unit_type': 100,
            }
        else:
            assert isinstance(embedding_dims, dict)
            self.embedding_dims = embedding_dims
        self.num_classes = num_classes
        self.include = include
        self.cnn_channel_size = 0

        # Embedding layers: one per included spatial feature.
        self.embeddings = nn.ModuleDict()
        for name, feat in SPATIAL_FEATURES._asdict().items():
            if name not in self.include:
                continue
            feat_type = str(feat.type).split('.')[-1]
            if feat_type == 'CATEGORICAL':
                self.embeddings[name] = CategoricalEmbedding(
                    category_size=feat.scale,
                    embedding_dim=self.embedding_dims.get(name),
                    name=name,
                )
            elif feat_type == 'SCALAR':
                self.embeddings[name] = ScalarEmbedding(
                    embedding_dim=self.embedding_dims.get(name),
                    name=name,
                )
            else:
                raise NotImplementedError(f"Unsupported feature type: '{feat_type}'")
            self.cnn_channel_size += self.embedding_dims.get(name)

        # Conv stem, three residual blocks with interleaved pooling, and a
        # linear classifier head.
        self.conv1 = nn.Conv3d(self.cnn_channel_size, 32, 7, stride=(1, 2, 2),
                               padding=(3, 3, 3), bias=False)
        self.bn1 = nn.BatchNorm3d(32)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
        self.block1 = BasicBlock3D(32, 64, stride=1, downsample=None)
        self.pool1 = nn.MaxPool3d(kernel_size=(5, 4, 4), padding=0)
        self.block2 = BasicBlock3D(64, 128, stride=2, downsample=None)
        self.pool2 = nn.MaxPool3d(kernel_size=(1, 2, 2), padding=0)
        self.block3 = BasicBlock3D(128, 256, stride=2, downsample=None)
        self.pool3 = nn.AvgPool3d(kernel_size=(5, 4, 4), stride=1, padding=0)
        self.linear = nn.Linear(256, self.num_classes)
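
# CategoricalEmbedding and ScalarEmbedding are project-local modules not
# shown in this file. As a rough, hypothetical sketch of what the categorical
# case could look like (assumed behavior and names, not the project's actual
# implementation): an nn.Embedding lookup that maps an integer-coded
# (N, T, H, W) feature plane to a dense (N, E, T, H, W) tensor.
class CategoricalEmbeddingSketch(nn.Module):
    def __init__(self, category_size, embedding_dim, name=None):
        super().__init__()
        self.name = name
        # category_size + 2 leaves room for ids up to scale + 1, matching the
        # clipping done in replay_collate_fn (an assumption, not project code).
        self.embed = nn.Embedding(category_size + 2, embedding_dim)

    def forward(self, x):
        out = self.embed(x.long())          # (N, T, H, W, E)
        return out.permute(0, 4, 1, 2, 3)   # (N, E, T, H, W)

# Example: CategoricalEmbeddingSketch(4, 8)(torch.randint(0, 6, (2, 3, 16, 16)))
# returns a tensor of shape (2, 8, 3, 16, 16).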

if __name__ == '__main__':
    BATCH_SIZE = 2
    WEIGHTING = 'log'
    MAX_TIMESTEPS = 50
    SPATIAL_SPECS = SPATIAL_FEATURES._asdict()
    ROOT = './parsed/TvP/'

    dataset = SC2ReplayDataset(root_dir=ROOT, train=True,
                               max_timesteps=MAX_TIMESTEPS)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=functools.partial(replay_collate_fn,
                                     max_timesteps=MAX_TIMESTEPS,
                                     weighting=WEIGHTING),
    )
    print(f"Number of batches per epoch: {len(dataloader)}")
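    # Hypothetical smoke test (added for illustration, not in the original
    # script): pull one batch and inspect its structure. Each spatial tensor
    # should have shape (BATCH_SIZE, num_sampled_frames, H, W), with one
    # entry per feature plane, plus a (BATCH_SIZE,) target vector.
    batch = next(iter(dataloader))
    for name, tensor in batch['inputs'].items():
        print(name, tuple(tensor.shape))
    print('targets:', tuple(batch['targets'].shape))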

class SimpleConvLSTM(nn.Module):
    """Per-frame CNN encoder followed by a GRU with attention pooling."""

    def __init__(self, embedding_dim, rnn_input_size, rnn_hidden_size,
                 output_size,
                 include=('height_map', 'visibility_map', 'player_relative',
                          'unit_type')):
        super(SimpleConvLSTM, self).__init__()
        self.embedding_dim = embedding_dim
        self.rnn_input_size = rnn_input_size
        self.rnn_hidden_size = rnn_hidden_size
        self.output_size = output_size
        self.include = include
        self.embedding_dims = {
            'height_map': 10,
            'visibility_map': 10,
            'player_relative': 10,
            'unit_type': 100,
        }
        self.cnn_channel_size = 0
        self.feat_names = list(SPATIAL_FEATURES._asdict())  # redundant?

        # Embedding layers: one per included spatial feature.
        self.embeddings = nn.ModuleDict()
        for name, feat in SPATIAL_FEATURES._asdict().items():
            if name not in self.include:
                continue
            feat_type = str(feat.type).split('.')[-1]
            if feat_type == 'CATEGORICAL':
                self.embeddings[name] = CategoricalEmbedding(
                    category_size=feat.scale,
                    embedding_dim=self.embedding_dims.get(name),
                    name=name,
                )
            elif feat_type == 'SCALAR':
                self.embeddings[name] = ScalarEmbedding(
                    embedding_dim=self.embedding_dims.get(name),
                    name=name,
                )
            else:
                raise NotImplementedError(f"Unsupported feature type: '{feat_type}'")
            self.cnn_channel_size += self.embedding_dims.get(name)

        # Convolutional per-frame encoder.
        self.conv = SimpleConv(
            in_channels=self.cnn_channel_size,
            output_size=self.rnn_input_size,
        )
        # Recurrent module over the frame sequence.
        self.gru = SimpleGRU(
            input_size=self.rnn_input_size,
            hidden_size=self.rnn_hidden_size,
            output_size=self.output_size,
        )
        # Attention module over the GRU hidden states.
        self.attn = VaswaniAttention(
            hidden_size=self.rnn_hidden_size,
            context_size=self.rnn_hidden_size,
        )
        self.linear = nn.Linear(self.rnn_hidden_size, self.output_size)
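
# VaswaniAttention is defined elsewhere in the project. As a self-contained,
# hypothetical sketch of the underlying idea (scaled dot-product attention
# pooling over the GRU's hidden states, after Vaswani et al., 2017; not the
# project's actual module), a single learned query can summarize a sequence:
class DotProductAttentionPoolSketch(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.query = nn.Parameter(torch.randn(hidden_size))

    def forward(self, hiddens):                   # hiddens: (N, T, H)
        scale = hiddens.size(-1) ** 0.5
        scores = hiddens @ self.query / scale     # (N, T)
        weights = torch.softmax(scores, dim=1)    # attention over timesteps
        context = (weights.unsqueeze(-1) * hiddens).sum(dim=1)  # (N, H)
        return context, weights

# Example: pooling 4 sequences of 50 hidden states of size 128 yields a
# (4, 128) summary that a linear head can map to output_size logits:
#   context, _ = DotProductAttentionPoolSketch(128)(torch.randn(4, 50, 128))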