Beispiel #1
0
def create_graph_from_tu_data(graph_data, target, num_node_labels,
                              num_edge_labels, Graph_whole):
    """Build a ``Graph`` for one sample of a TU-style dataset.

    ``graph_data`` is a mapping read through the keys ``graph_nodes``,
    ``graph_edges``, ``node_labels``, ``node_attrs``, ``edge_labels`` and
    ``edge_attrs``.  A label list that is not ``[]`` is one-hot encoded
    (``num_node_labels`` / ``num_edge_labels`` classes); an attribute list
    that is not ``[]`` is passed through per item.  Missing labels or
    attributes are stored as ``None``.

    ``Graph_whole`` is accepted for interface compatibility but is not used
    inside this function.

    Returns the populated ``Graph`` created with ``target``.
    """
    node_ids = graph_data["graph_nodes"]
    edge_pairs = graph_data["graph_edges"]

    graph = Graph(target=target)

    for idx, node_id in enumerate(node_ids):
        node_label = (one_hot(graph_data["node_labels"][idx], num_node_labels)
                      if graph_data["node_labels"] != [] else None)
        node_attrs = (graph_data["node_attrs"][idx]
                      if graph_data["node_attrs"] != [] else None)
        graph.add_node(node_id, label=node_label, attrs=node_attrs)

    for idx, (src, dst) in enumerate(edge_pairs):
        edge_label = (one_hot(graph_data["edge_labels"][idx], num_edge_labels)
                      if graph_data["edge_labels"] != [] else None)
        edge_attrs = (graph_data["edge_attrs"][idx]
                      if graph_data["edge_attrs"] != [] else None)
        graph.add_edge(src, dst, label=edge_label, attrs=edge_attrs)

    return graph
Beispiel #2
0
def create_graph_from_tu_data(graph_data, target, num_node_labels, num_edge_labels):
    """Assemble a ``Graph`` from the per-graph dictionary ``graph_data``.

    Nodes come from ``graph_data["graph_nodes"]`` and edges (as 2-item
    pairs) from ``graph_data["graph_edges"]``.  For each node / edge the
    label is one-hot encoded when the corresponding label list is not
    ``[]`` and the attribute is taken as-is when the attribute list is not
    ``[]``; otherwise ``None`` is stored.

    Returns the ``Graph`` instance constructed with ``target``.
    """
    node_ids = graph_data["graph_nodes"]
    edge_pairs = graph_data["graph_edges"]

    graph = Graph(target=target)

    for position, node_id in enumerate(node_ids):
        node_label = None
        node_attrs = None
        if graph_data["node_labels"] != []:
            node_label = one_hot(graph_data["node_labels"][position],
                                 num_node_labels)
        if graph_data["node_attrs"] != []:
            node_attrs = graph_data["node_attrs"][position]
        graph.add_node(node_id, label=node_label, attrs=node_attrs)

    for position, (first, second) in enumerate(edge_pairs):
        edge_label = None
        edge_attrs = None
        if graph_data["edge_labels"] != []:
            edge_label = one_hot(graph_data["edge_labels"][position],
                                 num_edge_labels)
        if graph_data["edge_attrs"] != []:
            edge_attrs = graph_data["edge_attrs"][position]
        graph.add_edge(first, second, label=edge_label, attrs=edge_attrs)

    return graph
def bounded_logit_source_sink(perturbed_logit,
                              clean_class,
                              source_classes,
                              sink_classes,
                              num_classes,
                              confidence=0.0,
                              use_cuda=False):
    """Per-sample source/sink logit loss.

    For every ``(source, sink)`` pair (zipped from ``source_classes`` and
    ``sink_classes``) the samples whose ``clean_class`` equals ``source``
    receive the average of two terms:

    * source term: ``clamp(source_logit - best_other_logit, min=0)``,
      which pushes the source-class logit down;
    * sink term: ``clamp(best_non_sink_logit - sink_logit,
      min=-confidence)``, which pushes the sink-class logit up.

    The "best other" logits are detached, so gradients only flow through
    the source / sink logits themselves.

    Args:
        perturbed_logit: logits indexed as ``[sample, class]``.
        clean_class: 1-D tensor of class indices, one per sample.
        source_classes: iterable of source class ids.
        sink_classes: iterable of sink class ids, paired with the sources.
        num_classes: number of classes for the one-hot encoding.
        confidence: lower bound (negated) of the sink term.
        use_cuda: move the helper tensors created here to the GPU.

    Returns:
        1-D tensor of losses.  NOTE: the entries are concatenated pair by
        pair, so they are grouped by source class rather than kept in the
        input sample order.

    Raises:
        AssertionError: if the number of produced losses differs from
            ``len(clean_class)`` (e.g. a sample whose class is not listed
            in ``source_classes``).
    """
    one_hot_labels = one_hot(clean_class.cpu(), num_classes=num_classes)
    if use_cuda:
        one_hot_labels = one_hot_labels.cuda()

    loss = torch.tensor([])
    if use_cuda:
        loss = loss.cuda()

    for source_cl, sink_cl in zip(source_classes, sink_classes):
        # Select all samples of the source class with one vectorised
        # comparison.  The mask is kept on the CPU, as it was before.
        source_cl_mask = (clean_class == source_cl).cpu()
        if not source_cl_mask.any():
            continue

        clean_class_source_cl = clean_class[source_cl_mask]
        one_hot_labels_source_cl = one_hot_labels[source_cl_mask]
        perturbed_logit_source_cl = perturbed_logit[source_cl_mask]

        # source loss: decrease the source-class logit
        class_logits_source_cl = (one_hot_labels_source_cl *
                                  perturbed_logit_source_cl).sum(1)
        # Subtracting 10000 at the labelled position removes it from the max.
        not_class_logits_source_cl = (
            (1. - one_hot_labels_source_cl) * perturbed_logit_source_cl -
            one_hot_labels_source_cl * 10000.).max(1)[0].detach()
        source_cl_loss = torch.clamp(class_logits_source_cl -
                                     not_class_logits_source_cl,
                                     min=0)

        # sink loss: increase the sink-class logit
        target_sink_class = torch.ones_like(clean_class_source_cl) * sink_cl
        one_hot_labels_sink_cl = one_hot(target_sink_class.cpu(),
                                         num_classes=num_classes)
        if use_cuda:
            one_hot_labels_sink_cl = one_hot_labels_sink_cl.cuda()
        class_logits_sink_cl = (one_hot_labels_sink_cl *
                                perturbed_logit_source_cl).sum(1)
        not_class_logits_sink_cl = (
            (1. - one_hot_labels_sink_cl) * perturbed_logit_source_cl -
            one_hot_labels_sink_cl * 10000.).max(1)[0].detach()
        sink_cl_loss = torch.clamp(not_class_logits_sink_cl -
                                   class_logits_sink_cl,
                                   min=-confidence)

        loss_source_cl = (source_cl_loss + sink_cl_loss) / 2.

        loss = torch.cat((loss, loss_source_cl), 0)

    # Sanity check: every sample must have been covered by exactly one pair.
    assert len(loss) == len(clean_class)  # Can be deleted after a few tries

    return loss
Beispiel #4
0
    def test_on_trainings_set(self):
        """Evaluate the model on the training loader and log sample images.

        Puts the model in eval mode, sums ``self.loss`` over every batch of
        ``self.train_loader``, and every 50th batch writes up to 8 inputs
        stacked with their reconstructions to ``self.summary_writer``.  The
        summed loss is divided by the number of samples in the training
        dataset, printed, and the model is switched back to train mode.

        NOTE(review): ``Variable(..., volatile=True)`` and ``.data[0]`` are
        kept as in the original; they look like pre-0.4 PyTorch idioms.
        """
        print('testing...')
        self.model.eval()
        test_loss = 0
        for i, (data, label) in enumerate(self.train_loader):
            one_hot_matrix = Variable(one_hot(label, self.num_categories))
            if self.args.cuda:
                data = data.cuda()
            data = Variable(data, volatile=True)
            recon_batch, mu, logvar = self.model(data, one_hot_matrix)
            test_loss += self.loss(recon_batch, data, mu, logvar).data[0]
            # Index of the maximum along dim 1, rescaled by 1/255 so it can
            # be shown next to the input images.
            _, indices = recon_batch.max(1)
            indices.data = indices.data.float() / 255
            if i % 50 == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([
                    data[:n],
                    indices.view(-1, self.args.input_shape.channels, 32,
                                 32)[:n]
                ])
                self.summary_writer.add_image('training_set/image', comparison,
                                              i)

        # The loop runs over the training loader, so normalise by the size
        # of the training dataset (previously the test dataset was used).
        test_loss /= len(self.train_loader.dataset)
        print('====> Test on training set loss: {:.4f}'.format(test_loss))
        self.model.train()
Beispiel #5
0
    def train(self):
        """Run the training loop from ``args.start_epoch`` to ``args.num_epochs``.

        Each epoch iterates ``self.train_loader``, performs one optimizer
        step per batch, then prints and logs the mean batch loss and the
        learning rate returned by ``self.adjust_learning_rate``, saves a
        checkpoint, and calls ``self.test`` whenever the epoch number is a
        multiple of ``args.test_every``.
        """
        self.model.train()
        for epoch in range(self.args.start_epoch, self.args.num_epochs):
            epoch_losses = []
            print("epoch {}...".format(epoch))
            for data, label in tqdm(self.train_loader):
                one_hot_matrix = Variable(one_hot(label, self.num_categories))
                inputs = Variable(data.cuda() if self.args.cuda else data)

                self.optimizer.zero_grad()
                recon_batch, mu, logvar = self.model(inputs, one_hot_matrix)
                batch_loss = self.loss(recon_batch, inputs, mu, logvar)
                batch_loss.backward()
                self.optimizer.step()
                epoch_losses.append(batch_loss.data)

            mean_loss = np.mean(epoch_losses)
            print("epoch {}: - loss: {}".format(epoch, mean_loss))
            new_lr = self.adjust_learning_rate(epoch)
            print('learning rate:', new_lr)

            writer = self.summary_writer
            writer.add_scalar('training/loss', float(mean_loss), epoch)
            writer.add_scalar('training/learning_rate', new_lr, epoch)

            checkpoint = {
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            }
            self.save_checkpoint(checkpoint)

            if epoch % self.args.test_every == 0:
                self.test(epoch)
Beispiel #6
0
    def forward(self, input, target):
        """Return the mean negated target-class logit.

        Minimising this value increases the logit of ``target`` for every
        sample.

        Args:
            input: logits indexed as ``[sample, class]``.
            target: 1-D tensor of class indices.

        Returns:
            Scalar tensor ``mean(-input[i, target[i]])``.
        """
        one_hot_labels = one_hot(target.cpu(), num_classes=self.num_classes)
        if self.use_cuda:
            one_hot_labels = one_hot_labels.cuda()

        # Select the target logit with a sum over the one-hot mask.  A max
        # over the masked row would return 0 whenever the target logit is
        # negative, which removes the gradient for exactly those samples.
        logits = (one_hot_labels * input).sum(1)
        # Increase the logit value
        return torch.mean(-logits)
Beispiel #7
0
def gluonts_batch_to_train_pytorch(batch, device, dtype, dims,
                                   time_features: TimeFeatType):
    """Convert a GluonTS training batch into a ``Box`` of torch tensors.

    Reads ``past_target``, ``feat_static_cat``, ``past_time_feat`` and
    ``past_seasonal_indicators`` from ``batch`` (each converted through
    ``.asnumpy()``), moves the time axis to the front, one-hot encodes the
    first static category (5 classes) and the first seasonal indicator
    (7 classes), and selects the time-feature input according to
    ``time_features``.

    Returns a ``Box`` with fields ``y``, ``seasonal_indicators``,
    ``u_time`` and ``u_static_cat``.

    Raises:
        ValueError: if ``time_features`` matches none of the
            ``TimeFeatType`` members handled here.
    """
    n_timesteps = batch["past_target"].shape[1]

    # Repeat the first static category along a new leading time axis.
    static_cat_first = batch["feat_static_cat"].asnumpy()[..., 0]
    categories = np.expand_dims(static_cat_first,
                                axis=0).repeat(n_timesteps, axis=0)

    time_feat = torch.tensor(batch["past_time_feat"].asnumpy(),
                             dtype=torch.float32,
                             device=device)
    time_feat = time_feat.transpose(0, 1)

    assert categories.shape == (n_timesteps, dims.batch)
    category_ids = torch.tensor(categories).to(torch.int64)
    categories = one_hot(labels=category_ids, num_classes=5)
    categories = categories.to(dtype).to(device=device)

    seasonal_raw = batch["past_seasonal_indicators"][..., 0].asnumpy()
    seasonal_indicators = torch.tensor(seasonal_raw).transpose(0, 1)
    seasonal_indicators = seasonal_indicators.to(torch.int64).to(device)
    seasonal_indicator_feat = one_hot(
        labels=seasonal_indicators[:, :],
        num_classes=7,
        is_squeeze=True,
    )
    seasonal_indicator_feat = seasonal_indicator_feat.to(dtype).to(device)

    mode = time_features.value
    if mode == TimeFeatType.timefeat.value:
        u_time = time_feat
    elif mode == TimeFeatType.seasonal_indicator.value:
        u_time = seasonal_indicator_feat
    elif mode == TimeFeatType.both.value:
        u_time = torch.cat([time_feat, seasonal_indicator_feat], dim=-1)
    elif mode == TimeFeatType.none.value:
        u_time = None
    else:
        raise ValueError(
            f"unexpected value for time_features: {time_features}")

    targets = torch.tensor(batch["past_target"].asnumpy())
    targets = targets.to(dtype).to(device=device).transpose(1, 0)

    return Box(
        y=targets,
        seasonal_indicators=seasonal_indicators,
        u_time=u_time,
        u_static_cat=categories,
    )
Beispiel #8
0
    def forward(self, input, target):
        """Mean margin loss between the best non-target logit and the target.

        For each sample the loss is
        ``clamp(max_other_logit - target_logit, min=-self.confidence)``;
        the mean over all samples is returned.
        """
        target_mask = one_hot(target.cpu(), num_classes=self.num_classes)
        if self.use_cuda:
            target_mask = target_mask.cuda()

        target_logits = (target_mask * input).sum(1)
        # The large negative offset removes the target position from the max.
        other_logits = (1. - target_mask) * input - target_mask * 10000.
        best_other, _ = other_logits.max(1)
        margin = torch.clamp(best_other - target_logits,
                             min=-self.confidence)
        return margin.mean()
Beispiel #9
0
    def set_forward_loss(self, x):
        """Compute the episode loss for ``x``.

        Query labels are ``0..n_way-1``, each repeated ``n_query`` times.
        With ``loss_type == 'mse'`` the labels are one-hot encoded before
        being passed to ``self.loss_fn``; otherwise the class indices are
        passed directly.  Labels are moved to the GPU in both cases.
        """
        labels = torch.from_numpy(np.repeat(range(self.n_way), self.n_query))
        scores = self.set_forward(x)

        if self.loss_type != 'mse':
            return self.loss_fn(scores, Variable(labels.cuda()))

        one_hot_labels = utils.one_hot(labels, self.n_way)
        return self.loss_fn(scores, Variable(one_hot_labels.cuda()))
def logit_inc(perturbed_logit,
              clean_class,
              source_classes,
              sink_classes,
              num_classes=-1,
              confidence=0.0,
              use_cuda=False):
    """Return the negated logit of each sample's ``clean_class``.

    ``source_classes``, ``sink_classes`` and ``confidence`` are accepted so
    the signature matches the other logit losses but are not used here.
    The result has one entry per sample.
    """
    class_mask = one_hot(clean_class.cpu(), num_classes=num_classes)
    class_mask = class_mask.cuda() if use_cuda else class_mask
    clean_logit = torch.sum(class_mask * perturbed_logit, 1)
    return -clean_logit
Beispiel #11
0
    def compute_targets(self, rewards, next_states, non_ends, gamma):
        """Compute a batch of projected target distributions (distributional DQN).

        The distribution of the greedy next action (taken from the target
        network) is shifted by ``rewards + gamma * non_ends * zpoints``,
        clamped to ``[vmin, vmax]`` and redistributed onto the fixed atom
        support by linear interpolation between neighbouring atoms.

        params:
            rewards: Tensor [batch, 1]
            next_states: Tensor [batch, channel, w, h]
            non_ends: Tensor [batch, 1]
            gamma: float

        returns:
            Tensor [batch, num_atoms] of probability mass, moved to
            ``device`` (a name resolved outside this method).
        """
        # get next distribution
        # NOTE(review): ``volatile=True`` looks like the pre-0.4 PyTorch way
        # of disabling autograd -- confirm the targeted version.
        next_states = Variable(next_states, volatile=True)

        # [batch, num_actions], [batch, num_actions, num_atoms]
        next_q_vals, next_probs = self._q_values(self.target_q_net,
                                                 next_states)
        # index of the maximum q-value per sample (keepdim=True)
        next_actions = next_q_vals.data.max(1, True)[1]  # [batch, 1]
        # one-hot mask over actions with a trailing axis for broadcasting
        # against the atom dimension -- assumes one_hot returns
        # [batch, num_actions]; TODO confirm
        next_actions = one_hot(next_actions, self.num_actions,
                               device).unsqueeze(2)
        # summing over the action axis keeps only the greedy action's atoms
        next_greedy_probs = (next_actions * next_probs.data).sum(1)

        # transform the distribution
        rewards = rewards
        non_ends = non_ends
        proj_zpoints = rewards + gamma * non_ends * self.zpoints.data
        proj_zpoints.clamp_(self.vmin, self.vmax)

        # project onto shared support
        # b is the fractional atom index of every projected point
        b = (proj_zpoints - self.vmin) / self.delta_z
        lower = b.floor()
        upper = b.ceil()
        # handle corner case where b is integer
        # (floor == ceil would give both interpolation weights 0 and drop
        # the mass, so lower is moved one atom down ...)
        eq = (upper == lower).float()
        lower -= eq
        # (... and if that pushed lower below 0, both bounds are shifted
        # back up by one so the indices stay valid)
        lt0 = (lower < 0).float()
        lower += lt0
        upper += lt0

        # note: it's faster to do the following on cpu
        # ml / mu: mass assigned to the lower / upper neighbouring atom
        ml = (next_greedy_probs * (upper - b)).cpu().numpy()
        mu = (next_greedy_probs * (b - lower)).cpu().numpy()

        lower = lower.cpu().numpy().astype(np.int32)
        upper = upper.cpu().numpy().astype(np.int32)

        batch_size = rewards.size(0)
        mass = np.zeros((batch_size, self.num_atoms), dtype=np.float32)
        brange = range(batch_size)
        # One source atom at a time: within a single column every batch row
        # occurs exactly once, so the fancy-indexed ``+=`` never hits the
        # same (row, atom) cell twice in one statement.
        for i in range(self.num_atoms):
            mass[brange, lower[brange, i]] += ml[brange, i]
            mass[brange, upper[brange, i]] += mu[brange, i]

        return torch.from_numpy(mass).to(device)
Beispiel #12
0
    def set_forward(self, x, is_feature=False):
        """Return log-probabilities of the query samples for one episode.

        Support and query features from ``self.parse_feature`` are flattened
        to one row per sample, the support rows are encoded with
        ``self.encode_training_set``, and the result is combined with the
        one-hot support labels (on the GPU) in ``self.get_logprobs``.
        """
        z_support, z_query = self.parse_feature(x, is_feature)

        n_support_rows = self.n_way * self.n_support
        support_flat = z_support.contiguous().view(n_support_rows, -1)
        n_query_rows = self.n_way * self.n_query
        query_flat = z_query.contiguous().view(n_query_rows, -1)

        G, G_normalized = self.encode_training_set(support_flat)

        # Support labels: 0..n_way-1, each repeated n_support times.
        support_labels = torch.from_numpy(
            np.repeat(range(self.n_way), self.n_support))
        Y_S = Variable(utils.one_hot(support_labels, self.n_way)).cuda()

        return self.get_logprobs(query_flat, G, G_normalized, Y_S)
def bounded_logit_inc_sink(perturbed_logit,
                           clean_class,
                           source_classes,
                           sink_classes,
                           num_classes,
                           confidence=0.0,
                           use_cuda=False):
    """Per-sample bounded margin loss for the class given by ``clean_class``.

    Computes ``clamp(best_other_logit - class_logit, min=-confidence)`` for
    every sample.  ``source_classes`` and ``sink_classes`` are accepted for
    signature compatibility and are not used here.
    """
    class_mask = one_hot(clean_class.cpu(), num_classes=num_classes)
    if use_cuda:
        class_mask = class_mask.cuda()

    own_logit = (class_mask * perturbed_logit).sum(1)
    # The large negative offset keeps the labelled position out of the max.
    masked_others = (1. - class_mask) * perturbed_logit - class_mask * 10000.
    strongest_other = masked_others.max(1)[0]
    return torch.clamp(strongest_other - own_logit, min=-confidence)
Beispiel #14
0
    def __getitem__(self, index):
        """Load, preprocess and return the sample at ``index``.

        The image is read from ``self.image_dir`` (first three channels
        only) and the label map from ``self.label_dir``; how the label file
        is decoded depends on ``self.dataset``.  The label map is optionally
        remapped (``self.labels2keep``), one-hot encoded, and both image
        and target go through the optional fixed-size, crop/pad and
        user-supplied transforms.

        Returns:
            dict with keys ``'image'``, ``'target'`` and ``'initial_size'``
            (the image height/width before any transform).

        Raises:
            IndexError: if ``index`` is not smaller than the number of
                images.
        """
        if index >= len(self.image_list):
            # Previously this case printed a message and then crashed on an
            # unbound ``sample``; IndexError is what the sequence protocol
            # expects for an out-of-range index.
            raise IndexError('Index out of Data Range')

        image_path = os.path.join(self.image_dir, self.image_list[index])
        img_x = np.array(Image.open(image_path))[:, :, 0:3]

        label_path = os.path.join(self.label_dir, self.label_list[index])

        if self.dataset == "Cityscape":
            target = np.array(Image.open(label_path).convert('L'),
                              dtype='long')
            target[target == -1] = 34
        elif self.dataset in ("Carla", "CGMU"):
            # Both datasets keep the class id in the first channel.
            target = np.array(Image.open(label_path), dtype='long')[:, :, 0]

        if self.labels2keep:
            target = self.targetTransform(target)

        initial_size = img_x.shape[0:2]

        img_x = self.default_transform(img_x)
        target = one_hot(target, self.num_classes, self.device)

        if self.fixed_size:
            img_x, target = self.functional_fixed_size((img_x, target))

        if self.crop:
            img_x, target = self.initialCropPad(img_x, initial_size, target)

        if self.transforms is not None:
            img_x, target = self.transforms((img_x, target))

        sample = {
            'image': img_x,
            'target': target,
            'initial_size': initial_size
        }

        return sample
Beispiel #15
0
 def forward(
     self,
     feat_static_cat: torch.Tensor,
     time_feat: torch.Tensor,
     seasonal_indicators: Optional[torch.Tensor] = None,
 ) -> (ControlInputsSGLS, ControlInputsSGLSISSM):
     """Build control inputs from static categories and time features.

     The static categories are one-hot encoded (cast back to the dtype of
     ``feat_static_cat``), repeated for ``len(time_feat)`` time steps,
     concatenated with ``time_feat`` on the last axis and passed through
     ``self.mlp``.  The resulting features are handed to
     ``self._all_same_controls`` together with ``seasonal_indicators``.
     """
     static_one_hot = one_hot(
         feat_static_cat,
         num_classes=self.num_classes,
     )
     static_one_hot = static_one_hot.to(dtype=feat_static_cat.dtype)

     static_repeated = self._repeat_static_feats(
         feat_static=static_one_hot,
         n_timesteps=len(time_feat),
     )
     mlp_input = torch.cat((static_repeated, time_feat), dim=-1)
     ctrl_features = self.mlp(mlp_input)

     return self._all_same_controls(
         ctrl_features=ctrl_features,
         seasonal_indicators=seasonal_indicators,
     )
Beispiel #16
0
def test_fusion_manually(distributions, threshold=1e-4):
    """Check ``ProbabilisticSensorFusion`` against a manual log-prob sum.

    The fused distribution must equal the product of the input
    distributions up to normalisation.  The normaliser cancels when
    log-probabilities of two points are subtracted, so the check compares
    *differences* of log-probs between consecutive random test points for
    the fused distribution and for the manually summed log-probs.

    Args:
        distributions: sequence of distributions of the same type.
        threshold: maximum allowed absolute deviation.

    Returns:
        ``True`` if the check passes.

    Raises:
        AssertionError: if any deviation reaches ``threshold``.
    """
    fused_dist = ProbabilisticSensorFusion()(distributions)
    first = distributions[0]
    sample_shape = ((100, ) + tuple(fused_dist.batch_shape) +
                    tuple(fused_dist.event_shape))

    if type(first) in [Categorical]:
        test_points = torch.randint(
            low=0,
            high=first.probs.shape[-1],
            size=sample_shape,
        )
    elif type(first) in [OneHotCategorical]:
        num_classes = first.probs.shape[-1]
        # ``high`` is exclusive in torch.randint; using ``num_classes``
        # (not ``num_classes - 1``) lets the last class be sampled as well.
        test_points = one_hot(
            torch.randint(
                low=0,
                high=num_classes,
                size=sample_shape,
            ),
            num_classes=num_classes,
        )
    elif type(first) in [Bernoulli]:
        test_points = torch.randint(
            low=0,
            high=2,
            size=sample_shape,
        )
        # Bernoulli is evaluated on float values while the categorical
        # variants take integers, so cast to the dtype of the logits.
        test_points = test_points.to(first.logits.dtype)
    else:
        test_points = torch.randn(sample_shape)

    log_prob_fused = fused_dist.log_prob(test_points)
    log_probs = [dist.log_prob(test_points) for dist in distributions]
    manually_fused_unnormalized = torch.stack(log_probs, dim=-1).sum(dim=-1)

    diff_log_prob_fused = log_prob_fused[1:] - log_prob_fused[:-1]
    diff_log_prob_manual = (manually_fused_unnormalized[1:] -
                            manually_fused_unnormalized[:-1])
    # Compare the absolute deviation: without abs() any large negative
    # difference would pass the check unnoticed.
    deviation = (diff_log_prob_fused - diff_log_prob_manual).abs()
    assert torch.all(deviation < threshold)
    return True
Beispiel #17
0
    def learn(self, experiences, gamma):
        """Run one optimisation step on a batch of experience.

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple unpacked as
                (states, actions, rewards, next_states, non_ends)
            gamma (float): discount factor

        The actions are one-hot encoded, target distributions come from
        ``self.compute_targets``, and after the optimizer step (followed by
        clearing the gradients) the target network is soft-updated with
        ``TAU``.
        """
        states, actions, rewards, next_states, non_ends = experiences

        action_mask = Variable(one_hot(actions, self.action_size, device))
        target_dist = Variable(
            self.compute_targets(rewards, next_states, non_ends, gamma))

        loss = self.loss(Variable(states), action_mask, target_dist)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        # ------------------- update target network ------------------- #
        self.soft_update(self.online_q_net, self.target_q_net, TAU)
Beispiel #18
0
    def set_forward_adaptation(self,
                               x,
                               is_feature=True):  # overwrite parent function
        """Fine-tune the relation module on the support set, then score queries.

        The relation module is trained for 100 SGD steps on sub-episodes
        built from shuffled support features (with ``n_support`` and
        ``n_query`` temporarily set to 3 and 2).  Relation scores for the
        real queries are then computed with the adapted module, after which
        the module's original weights are restored.

        Args:
            x: episode features, split by ``self.parse_feature``.
            is_feature: must be ``True``; only pre-computed features are
                supported.

        Returns:
            Relation scores reshaped to ``(-1, n_way)``.

        NOTE(review): there is no try/finally, so if an exception is raised
        midway, ``n_support`` / ``n_query`` and the adapted weights are left
        in their temporary state.
        """
        assert is_feature == True, 'Finetune only support fixed feature'
        # Remember the real episode layout; it is overridden below.
        full_n_support = self.n_support
        full_n_query = self.n_query
        # Snapshot of the current relation-module weights, restored at the
        # end.  NOTE(review): the literal 8 is presumably the hidden size --
        # confirm against RelationModule.
        relation_module_clone = RelationModule(self.feat_dim, 8,
                                               self.loss_type)
        relation_module_clone.load_state_dict(
            self.relation_module.state_dict())

        z_support, z_query = self.parse_feature(x, is_feature)
        z_support = z_support.contiguous()
        set_optimizer = torch.optim.SGD(self.relation_module.parameters(),
                                        lr=0.01,
                                        momentum=0.9,
                                        dampening=0.9,
                                        weight_decay=0.001)

        # Layout used for the sub-episodes during fine-tuning.
        self.n_support = 3
        self.n_query = 2

        z_support_cpu = z_support.data.cpu().numpy()
        for epoch in range(100):
            # Shuffle the support samples of every class with the same
            # permutation; presumably set_forward then splits them into
            # n_support / n_query -- verify against parse_feature.
            perm_id = np.random.permutation(full_n_support).tolist()
            sub_x = np.array([
                z_support_cpu[i, perm_id, :, :, :]
                for i in range(z_support.size(0))
            ])
            sub_x = torch.Tensor(sub_x).cuda()
            if self.change_way:
                self.n_way = sub_x.size(0)
            set_optimizer.zero_grad()
            # Labels 0..n_way-1, each repeated n_query times.
            y = torch.from_numpy(np.repeat(range(self.n_way), self.n_query))
            scores = self.set_forward(sub_x, is_feature=True)
            if self.loss_type == 'mse':
                y_oh = utils.one_hot(y, self.n_way)
                y_oh = Variable(y_oh.cuda())

                loss = self.loss_fn(scores, y_oh)
            else:
                y = Variable(y.cuda())
                loss = self.loss_fn(scores, y)
            loss.backward()
            set_optimizer.step()

        # Restore the real episode layout before scoring the queries.
        self.n_support = full_n_support
        self.n_query = full_n_query
        # Class prototypes: mean over the support samples of each class.
        z_proto = z_support.view(self.n_way, self.n_support,
                                 *self.feat_dim).mean(1)
        z_query = z_query.contiguous().view(self.n_way * self.n_query,
                                            *self.feat_dim)

        # Pair every query with every prototype.
        z_proto_ext = z_proto.unsqueeze(0).repeat(self.n_query * self.n_way, 1,
                                                  1, 1, 1)
        z_query_ext = z_query.unsqueeze(0).repeat(self.n_way, 1, 1, 1, 1)
        z_query_ext = torch.transpose(z_query_ext, 0, 1)
        # Concatenating along dim 2 doubles the first feature dimension;
        # feat_dim is copied so the attribute itself is not modified.
        extend_final_feat_dim = self.feat_dim.copy()
        extend_final_feat_dim[0] *= 2
        relation_pairs = torch.cat((z_proto_ext, z_query_ext),
                                   2).view(-1, *extend_final_feat_dim)
        relations = self.relation_module(relation_pairs).view(-1, self.n_way)

        # Undo the fine-tuning so later episodes start from the same weights.
        self.relation_module.load_state_dict(
            relation_module_clone.state_dict())
        return relations
Beispiel #19
0
def get_data(u_features, v_features, node_mode, adj_train, train_labels,
             train_u_indices, train_v_indices, val_labels, val_u_indices,
             val_v_indices, test_labels, test_u_indices, test_v_indices,
             class_values, one_hot_edge, soft_one_hot_edge, norm_label,
             ce_loss):
    """Pack node features and train/val/test edges into one ``Data`` object.

    The incoming ``*_labels`` are indices into ``class_values``.  Labels
    become the corresponding class values (optionally divided by the
    largest class value when ``norm_label`` is set), or stay integer
    indices when ``ce_loss`` is set.  Every edge is stored in both
    directions, with the ``v`` indices offset by the number of rows of
    ``adj_train``.  Edge attributes are one-hot, soft one-hot or the plain
    label value, depending on ``one_hot_edge`` / ``soft_one_hot_edge``.

    Node features come from ``create_node(adj_train, node_mode)`` unless
    both ``u_features`` and ``v_features`` are given, in which case they
    are stacked and zero-padded to a common width.
    """
    n_classes = len(class_values)

    # Keep the raw class indices; the *_labels names are reused for values.
    train_indices, val_indices, test_indices = (train_labels, val_labels,
                                                test_labels)
    train_labels = torch.FloatTensor(class_values[train_indices])
    val_labels = torch.FloatTensor(class_values[val_indices])
    test_labels = torch.FloatTensor(class_values[test_indices])
    if norm_label:
        largest = max(class_values)
        train_labels = train_labels / largest
        val_labels = val_labels / largest
        test_labels = test_labels / largest

    n_row, _ = adj_train.shape
    if (u_features is None) or (v_features is None):
        x = torch.tensor(create_node(adj_train, node_mode), dtype=torch.float)
    else:
        print("using given feature")
        n_u, dim_u = u_features.shape[0], u_features.shape[1]
        n_v, dim_v = v_features.shape[0], v_features.shape[1]
        x = torch.zeros((n_u + n_v, np.maximum(dim_u, dim_v)),
                        dtype=torch.float)
        x[:n_u, :dim_u] = torch.tensor(u_features, dtype=torch.float)
        x[n_u:, :dim_v] = torch.tensor(v_features, dtype=torch.float)

    def both_directions(u_indices, v_indices):
        # v nodes are numbered after the n_row u nodes.
        shifted_v = v_indices + n_row
        return torch.tensor([
            np.append(u_indices, shifted_v),
            np.append(shifted_v, u_indices)
        ],
                            dtype=int)

    def edge_attributes(rating_indices, rating_values):
        # Attributes are duplicated to match the two edge directions.
        if one_hot_edge:
            encoded = one_hot(rating_indices, n_classes)
        elif soft_one_hot_edge:
            encoded = soft_one_hot(rating_indices, n_classes)
        else:
            doubled = np.append(rating_values, rating_values)[:, None]
            return torch.tensor(doubled, dtype=torch.float)
        return torch.cat((encoded, encoded), 0)

    train_edge_index = both_directions(train_u_indices, train_v_indices)
    train_edge_attr = edge_attributes(train_indices, train_labels)

    val_edge_index = both_directions(val_u_indices, val_v_indices)
    val_edge_attr = edge_attributes(val_indices, val_labels)

    test_edge_index = both_directions(test_u_indices, test_v_indices)
    test_edge_attr = edge_attributes(test_indices, test_labels)

    if ce_loss:
        train_labels = torch.tensor(train_indices, dtype=int)
        val_labels = torch.tensor(val_indices, dtype=int)
        test_labels = torch.tensor(test_indices, dtype=int)

    return Data(x=x,
                train_edge_index=train_edge_index,
                train_edge_attr=train_edge_attr,
                train_labels=train_labels,
                val_edge_index=val_edge_index,
                val_edge_attr=val_edge_attr,
                val_labels=val_labels,
                test_edge_index=test_edge_index,
                test_edge_attr=test_edge_attr,
                test_labels=test_labels,
                edge_attr_dim=train_edge_attr.shape[-1],
                class_values=torch.FloatTensor(class_values),
                user_num=adj_train.shape[0])
Beispiel #20
0
 def C(self, seasonal_indicators):
     """One-hot encode the seasonal indicators with a singleton axis inserted.

     The result has ``self.n_state`` entries on the last axis and an extra
     axis of size 1 just before it; it is cast to ``self.dtype`` and moved
     to ``self.device``.
     """
     indicator_one_hot = one_hot(seasonal_indicators,
                                 num_classes=self.n_state)
     expanded = indicator_one_hot[..., None, :]
     return expanded.to(self.dtype).to(self.device)
Beispiel #21
0
 def R_diag_projector(self, seasonal_indicators):
     """One-hot encode the seasonal indicators with a trailing singleton axis.

     The one-hot axis has ``self.n_state`` entries; the result is cast to
     ``self.dtype`` and moved to ``self.device``.
     """
     indicator_one_hot = one_hot(seasonal_indicators,
                                 num_classes=self.n_state)
     projector = indicator_one_hot[..., None]
     return projector.to(self.dtype).to(self.device)
def transform_gluonts_to_pytorch(
    batch,
    time_features: TimeFeatType,
    cardinalities_feat_static_cat,
    cardinalities_season_indicators,
    bias_y=0.0,
    factor_y=1.0,
    device="cuda",
    dtype=torch.float32,
):
    """Convert a GluonTS batch into a ``Box`` of time-major torch tensors.

    Past (and, when present in ``batch``, future) seasonal indicators and
    time features are concatenated along the time axis after moving time
    to the front.  Static categories are repeated for every input time
    step.  The targets are shifted and scaled as
    ``(y - bias_y) / factor_y``.

    Returns a ``Box`` with fields ``y``, ``seasonal_indicators``,
    ``u_time`` and ``u_static_cat``.

    Raises:
        ValueError: if ``time_features`` matches none of the handled
            ``TimeFeatType`` members.
        AssertionError: if more than one static category feature is
            configured.
    """

    def past_and_future(past_key, future_key):
        # Time-major concatenation of the past part and the optional
        # future part of one batch field.
        pieces = [batch[past_key].asnumpy().transpose(1, 0, 2)]
        if future_key in batch:
            pieces.append(batch[future_key].asnumpy().transpose(1, 0, 2))
        return np.concatenate(pieces, axis=0)

    n_timesteps_inputs = batch["past_time_feat"].shape[1]
    if "future_time_feat" in batch:
        n_timesteps_inputs = (n_timesteps_inputs +
                              batch["future_time_feat"].shape[1])

    # Category features
    static_cat_np = batch["feat_static_cat"].asnumpy()
    feat_static_cat = []
    for idx, _ in enumerate(cardinalities_feat_static_cat):
        feat = torch.tensor(static_cat_np[..., idx], dtype=torch.int64)
        feat = feat[None, ...].repeat((n_timesteps_inputs, 1,))
        feat_static_cat.append(feat.to(device))

    # Season indicator features
    seasonal_indicators = torch.tensor(
        past_and_future("past_seasonal_indicators",
                        "future_seasonal_indicators"),
        dtype=torch.int64,
        device=device,
    )
    seasonal_one_hots = []
    for idx, num_classes in enumerate(cardinalities_season_indicators):
        encoded = one_hot(
            labels=seasonal_indicators[:, :, idx],
            num_classes=num_classes,
            is_squeeze=True,
        )
        seasonal_one_hots.append(encoded.to(dtype).to(device))
    seasonal_indicator_feat = torch.cat(seasonal_one_hots, dim=-1)

    # Time-Features
    time_feat = torch.tensor(
        past_and_future("past_time_feat", "future_time_feat"),
        dtype=dtype,
        device=device,
    )

    mode = time_features.value
    if mode == TimeFeatType.timefeat.value:
        u_time = time_feat
    elif mode == TimeFeatType.seasonal_indicator.value:
        u_time = seasonal_indicator_feat
    elif mode == TimeFeatType.both.value:
        u_time = torch.cat([seasonal_indicator_feat, time_feat], dim=-1)
    elif mode == TimeFeatType.none.value:
        u_time = None
    else:
        raise ValueError(
            f"unexpected value for time_features: {time_features}"
        )

    y = torch.tensor(batch["past_target"].asnumpy())
    y = y.to(dtype).to(device=device).transpose(1, 0)

    # for these experiments we have just 1 type of static feats.
    assert len(feat_static_cat) == 1
    return Box(
        y=(y - bias_y) / factor_y,
        seasonal_indicators=seasonal_indicators,
        u_time=u_time,
        u_static_cat=feat_static_cat[0],
    )
Beispiel #23
0
def _save_checkpoint(saver, sess, ckpt_dir, prefix, global_step):
    """Save ``sess`` with ``saver`` under ``ckpt_dir/prefix``.

    The directory is created with ``os.makedirs(..., exist_ok=True)`` so
    that missing parent directories are created as well and no error is
    raised if the directory already exists (or appears concurrently).
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    saver.save(sess, os.path.join(ckpt_dir, prefix), global_step=global_step)


def train():
    """Build the training graph and run the epoch/batch training loop.

    Takes no arguments and returns nothing. Settings are read from the
    module-level ``FLAGS``, ``config`` and ``user_config`` objects and all
    outputs are written below ``results_dir``:

    * ``logs``: the summary writer directory (summaries are only written
      when ``FLAGS.use_summary`` is set, every ``FLAGS.calc_summary`` steps).
    * ``global_ckpt/global_ckpt``: saved every ``FLAGS.epoch_model_ckpts``
      epochs.
    * ``Analysis_ckpts/epoch_<epoch>``: saved every
      ``FLAGS.epoch_analysis_breakpoints`` epochs.

    Training samples are fetched on first use and kept in memory afterwards;
    validation samples are all fetched before the session starts.
    """
    # Placeholders plus the pre-processed tensor that feeds the network.
    (input_ph, ground_truths_ph, ground_truths,
     pre_processed_input) = dh.get_place_holders()
    # The first ground-truth entry holds label ids; convert it to one-hot.
    one_hot_labels = utils.one_hot(
        ground_truths[0],
        is_color=False)  # TODO: add dictionary task-to-label-number
    autoencoder = utils.get_autoencoder(user_config.autoencoder,
                                        config.working_dataset, config.strided)

    logits = autoencoder.inference(pre_processed_input)

    processed_ground_truths = [
        one_hot_labels, ground_truths[1], ground_truths[2]
    ]
    loss_op, loss_list, multi_loss_class = lh.get_loss(
        logits, processed_ground_truths)

    # "leaning_rate" (sic) is the flag name as declared elsewhere.
    optimizer = tf.train.AdamOptimizer(FLAGS.leaning_rate)
    train_step = optimizer.minimize(loss_op)

    saver = tf.train.Saver()
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=config.gpu_memory_fraction)
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    gpu_options=gpu_options)
    if FLAGS.use_summary:
        sh.define_summaries(logits, ground_truths, processed_ground_truths,
                            loss_op, loss_list, multi_loss_class)
    num_of_train_examples = FLAGS.num_of_train_imgs
    statistic = Statistic(logits, loss_op, loss_list, input_ph,
                          ground_truths_ph, multi_loss_class,
                          processed_ground_truths)

    # Validation data is loaded eagerly, before the session is opened.
    val_input_img, val_gt = dh.init_data(FLAGS.num_of_val_imgs)
    for val_ind in range(FLAGS.num_of_val_imgs):
        val_input_img[val_ind], val_gt[val_ind] = dh.get_data(val_ind, 'val')

    with tf.Session(config=session_config) as sess:
        # NOTE(review): global_step is never advanced in this function, so
        # every checkpoint below is saved with the same global_step value.
        global_step = start_training(sess, autoencoder, saver)
        summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(
            os.path.join(results_dir, 'logs'), sess.graph)
        input_img, gt = dh.init_data(num_of_train_examples)
        input_batch = None
        gt_batch = None
        # training starts here
        step = 0
        for epoch in range(FLAGS.num_of_epchs):
            print("\nEpoch: " + str(epoch))
            # Number of samples currently collected in the pending batch.
            # Resetting it per epoch drops an incomplete batch left over
            # from the previous epoch.
            batch_fill = 0
            for ind in tqdm(np.random.permutation(num_of_train_examples)):
                # Load the sample on first use only.
                if input_img[ind] is None:
                    input_img[ind], gt[ind] = dh.get_data(ind, 'train')
                # ----- make the random batch ----
                if batch_fill == 0:
                    input_batch, gt_batch = input_img[ind], gt[ind]
                else:
                    input_batch, gt_batch = add_to_batch(
                        input_batch, gt_batch, input_img[ind], gt[ind])
                batch_fill += 1
                if batch_fill < FLAGS.batch:
                    continue
                batch_fill = 0
                # ---- batch is ready ----
                feed_dict = get_feed_dict(input_ph, ground_truths_ph,
                                          input_batch, gt_batch)
                sess.run(train_step, feed_dict=feed_dict)
                if FLAGS.use_summary and step % FLAGS.calc_summary == 0:
                    sh.handle_summarys(sess, logits, summary, summary_writer,
                                       step, feed_dict)
                step += 1
            statistic.handle_statistic(epoch, logits, sess, input_img, gt,
                                       val_input_img, val_gt)
            if epoch % FLAGS.epoch_model_ckpts == 0:
                _save_checkpoint(saver, sess,
                                 os.path.join(results_dir, 'global_ckpt'),
                                 'global_ckpt', global_step)
            if epoch % FLAGS.epoch_analysis_breakpoints == 0:
                _save_checkpoint(saver, sess,
                                 os.path.join(results_dir, 'Analysis_ckpts'),
                                 'epoch_' + str(epoch), global_step)
def runTraining(args):
    """Run the training loop for ``args.epochs`` epochs.

    Reads ``args.dataset``, ``args.mode`` and ``args.epochs``; everything
    else (network, optimizer, device, data loaders) comes from
    ``setup(args)``. ``args.mode`` selects the loss:

    * ``'full'``: ``ce_loss`` against ``full_mask``.
    * ``'unconstrained'``: ``partial_ce`` against ``weak_mask``.
    * anything else: ``partial_ce`` against ``weak_mask`` plus ``sizeLoss``
      against ``bounds``.

    Per-batch metrics are shown on the progress bar, and every 5th epoch
    ``saveImages`` is called with the validation loader.
    """
    print(f">>> Setting up to train on {args.dataset} with {args.mode}")
    net, optimizer, device, train_loader, val_loader = setup(args)

    ce_loss = CrossEntropy(idk=[0,
                                1])  # Supervise both background and foreground
    partial_ce = PartialCrossEntropy()  # Supervise only foreground
    sizeLoss = NaiveSizeLoss()

    for i in range(args.epochs):
        net.train()

        # Per-epoch metric buffers with one slot per batch; they are
        # re-created (zeroed) at the start of every epoch.
        log_ce = torch.zeros((len(train_loader)), device=device)
        log_sizeloss = torch.zeros((len(train_loader)), device=device)
        log_sizediff = torch.zeros((len(train_loader)), device=device)
        log_dice = torch.zeros((len(train_loader)), device=device)

        desc = f">> Training   ({i: 4d})"
        tq_iter = tqdm_(enumerate(train_loader),
                        total=len(train_loader),
                        desc=desc)
        for j, data in tq_iter:
            img = data["img"].to(device)
            full_mask = data["full_mask"].to(device)
            weak_mask = data["weak_mask"].to(device)

            bounds = data["bounds"].to(device)

            optimizer.zero_grad()

            # Sanity tests to see we loaded and encoded the data correctly
            assert 0 <= img.min() and img.max() <= 1
            B, _, W, H = img.shape
            assert B == 1  # Since we log the values in a simple way, doesn't handle more
            assert weak_mask.shape == (B, 2, W, H)
            # NOTE(review): the second operand is only the assertion message,
            # so a failure reports the (falsy) result of one_hot(weak_mask).
            # one_hot is defined elsewhere; presumably it checks that the
            # mask is one-hot encoded -- confirm.
            assert one_hot(weak_mask), one_hot(weak_mask)

            logits = net(img)
            # Logits are multiplied by 5 before the softmax over dim 1,
            # which makes the resulting distribution more peaked.
            segment_prob = F.softmax(5 * logits, dim=1)
            segment_oh = probs2one_hot(segment_prob)

            # Sum the one-hot prediction over the last two axes and keep
            # class index 1.
            pred_size = einsum("bkwh->bk", segment_oh)[:, 1]
            # Signed difference to the stored "true_size" entry of the
            # first item, class index 1.
            log_sizediff[j] = pred_size - data["true_size"][0, 1]
            log_dice[j] = dice_coef(segment_oh,
                                    full_mask)[0, 1]  # 1st item, 2nd class

            if args.mode == 'full':
                lossEpoch = ce_loss(segment_prob, full_mask)
                log_ce[j] = lossEpoch.item()

                log_sizeloss[j] = 0
            elif args.mode == 'unconstrained':
                ce_val = partial_ce(segment_prob, weak_mask)
                log_ce[j] = ce_val.item()
                lossEpoch = ce_val

                log_sizeloss[j] = 0
            else:
                # Any other mode value (not only 'constrained') ends up here.
                ce_val = partial_ce(segment_prob, weak_mask)
                log_ce[j] = ce_val.item()

                sizeLoss_val = sizeLoss(segment_prob, bounds)
                log_sizeloss[j] = sizeLoss_val.item()

                lossEpoch = ce_val + sizeLoss_val

            lossEpoch.backward()
            optimizer.step()

            # Running means over the batches seen so far in this epoch.
            # "LossSize" is only displayed when mode is exactly 'constrained'.
            tq_iter.set_postfix({
                "DSC":
                f"{log_dice[:j+1].mean():05.3f}",
                "SizeDiff":
                f"{log_sizediff[:j+1].mean():07.1f}",
                "LossCE":
                f"{log_ce[:j+1].mean():5.2e}",
                **({
                    "LossSize": f"{log_sizeloss[:j+1].mean():5.2e}"
                } if args.mode == 'constrained' else {})
            })
            # NOTE(review): tq_iter is also the loop iterable; whether this
            # explicit update double-counts depends on tqdm_ -- confirm.
            tq_iter.update(1)
        tq_iter.close()

        # Every 5th epoch (including epoch 0) hand the validation loader
        # to saveImages.
        if (i % 5) == 0:
            saveImages(net, val_loader, 1, i, args.dataset, args.mode, device)