Example #1
    def attention(self, x, adj):
        # x's shape: (nodes, n_heads, out_size)
        att = copy.deepcopy(adj)
        # e's shape: (nodes, n_heads, 2)
        e = F.squeeze(F.matmul(x[:, :, None, :], self.attention_W), 2)
        # e_row's shape: (nodes, n_heads, 1)
        e_row, e_col = F.split_axis(e, 2, axis=2)
        # Linear(concat(x_row, x_col)) = Linear_row(x_row) + Linear_col(x_col)
        h = F.squeeze(e_row[adj.row] + e_col[adj.col], 2)
        att_data = F.leaky_relu(h, 0.2)
        # Scaling trick for numerical stability
        att_data -= self.xp.max(att_data.data)
        att_data = F.exp(att_data)
        x = F.dropout(x, 0.6)

        output = []
        for att_data_i, xi in zip(F.split_axis(att_data, self.n_heads, axis=1),
                                  F.split_axis(x, self.n_heads, axis=1)):
            att.data = F.squeeze(att_data_i, 1)
            rowsum = F.sparse_matmul(
                att, self.xp.ones([att.shape[1], 1], dtype=att.data.dtype))
            rowsum = 1. / F.squeeze(rowsum, 1)
            # We could have converted rowsum to a diagonal matrix and used sparse_matmul,
            # but the current sparse_matmul does not support two sparse matrix inputs.
            att.data = att.data * rowsum[att.row]
            att.data = F.dropout(att.data, 0.6)
            output.append(F.sparse_matmul(att, F.squeeze(xi, 1)))
        output = F.concat(output, axis=1)
        return output
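A minimal, self-contained sketch of the row-normalization trick used in the loop above (multiplying the sparse matrix by a ones vector to get row sums, then rescaling the non-zero entries in place). The adjacency values and sizes below are invented for illustration; only chainer.utils.to_coo and F.sparse_matmul are the APIs actually exercised in the example.

import numpy as np
import chainer.functions as F
from chainer import utils

# Toy 3-node adjacency (values made up).
adj_dense = np.array([[0, 1, 1],
                      [1, 0, 0],
                      [1, 1, 0]], dtype=np.float32)
adj = utils.to_coo(adj_dense)          # CooMatrix with .data, .row, .col

# Row sums via a sparse x dense product with a ones vector.
ones = np.ones((adj.shape[1], 1), dtype=np.float32)
rowsum = F.sparse_matmul(adj, ones)    # shape (3, 1)
inv = 1.0 / F.squeeze(rowsum, 1)       # shape (3,)

# Rescale each non-zero entry by the inverse sum of its row.
adj.data = adj.data * inv.data[adj.row]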
Example #2
 def test_invalid_inputs(self):
     a = _setup_tensor(.5, 1, (1, 3, 3), numpy.float32, .75)
     b = _setup_tensor(.5, 1, (1, 3, 3), numpy.float32, .75)
     sp_a = utils.to_coo(a)
     sp_b = utils.to_coo(b)
     with self.assertRaises(ValueError):
         F.sparse_matmul(sp_a, sp_b, self.transa, self.transb)
     with self.assertRaises(ValueError):
         F.sparse_matmul(a, b, self.transa, self.transb)
Example #3
 def test_invalid_shape(self):
     a = _setup_tensor(.5, 1, (1, 2, 3), numpy.float32, .75)
     b = _setup_tensor(.5, 1, (1, 4, 5), numpy.float32, .75)
     sp_a = utils.to_coo(a)
     sp_b = utils.to_coo(b)
     with self.assertRaises(type_check.InvalidType):
         F.sparse_matmul(sp_a, b, self.transa, self.transb)
     with self.assertRaises(type_check.InvalidType):
         F.sparse_matmul(a, sp_b, self.transa, self.transb)
Example #6
    def __call__(self, x, adj):
        if isinstance(x, chainer.utils.CooMatrix):
            x = F.sparse_matmul(x, self.W)
        else:
            x = F.matmul(x, self.W)
        output = F.sparse_matmul(adj, x)

        if self.b is not None:
            output += self.b

        return output
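A short usage sketch for graph-convolution layers like the two shown here and in the next example: project dense node features with a weight matrix, then propagate them over a sparse adjacency. The sizes, the tridiagonal adjacency, and the random weights below are all invented for illustration.

import numpy as np
import chainer.functions as F
from chainer import utils

n_nodes, in_size, out_size = 4, 8, 16
x = np.random.randn(n_nodes, in_size).astype(np.float32)
W = np.random.randn(in_size, out_size).astype(np.float32)

# Row-normalized adjacency with self-loops, stored as a CooMatrix.
A = (np.eye(n_nodes) + np.eye(n_nodes, k=1) + np.eye(n_nodes, k=-1)).astype(np.float32)
A /= A.sum(axis=1, keepdims=True)
adj = utils.to_coo(A)

support = F.matmul(x, W)                # (n_nodes, out_size), dense projection
output = F.sparse_matmul(adj, support)  # (n_nodes, out_size), sparse propagation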
Example #7
    def __call__(self, x, adj):
        support = F.matmul(x, self.W)
        output = F.sparse_matmul(adj, support)

        if self.b is not None:
            output += self.b

        return output
    def forward(self, gs: GraphsData, x: VariableOrArray) -> Variable:
        x_new = x

        FI_fc1a_x = self.FI.fc1a(x)
        FI_fc1b_x = self.FI.fc1b(x)
        FO_fc1a_x = self.FO.fc1a(x)
        FO_fc1b_x = self.FO.fc1b(x)
        FI_inputs = FI_fc1a_x[gs.edges[:, 0]] + FI_fc1b_x[gs.edges[:, 1]]
        FO_inputs = FO_fc1a_x[gs.edges[:, 0]] + FO_fc1b_x[gs.edges[:, 1]]

        FI_outputs = self.FI(FI_inputs)
        FO_outputs = self.FO(FO_inputs)

        d = F.sparse_matmul(gs.MI, FI_outputs) + \
            F.sparse_matmul(gs.MO, FO_outputs)

        x_new += d

        if self._order_preserving:
            FL_fc1a_x = self.FL.fc1a(x)
            FL_fc1b_x = self.FL.fc1b(x)
            FL_fc1c_x = self.FL.fc1c(x)
            FH_fc1a_x = self.FH.fc1a(x)
            FH_fc1b_x = self.FH.fc1b(x)
            FH_fc1c_x = self.FH.fc1c(x)
            FR_fc1a_x = self.FR.fc1a(x)
            FR_fc1b_x = self.FR.fc1b(x)
            FR_fc1c_x = self.FR.fc1c(x)
            FL_inputs = FL_fc1a_x[gs.treelets[:, 0]] + FL_fc1b_x[
                gs.treelets[:, 1]] + FL_fc1c_x[gs.treelets[:, 2]]
            FH_inputs = FH_fc1a_x[gs.treelets[:, 0]] + FH_fc1b_x[
                gs.treelets[:, 1]] + FH_fc1c_x[gs.treelets[:, 2]]
            FR_inputs = FR_fc1a_x[gs.treelets[:, 0]] + FR_fc1b_x[
                gs.treelets[:, 1]] + FR_fc1c_x[gs.treelets[:, 2]]

            FL_outputs = self.FL(FL_inputs)
            FH_outputs = self.FH(FH_inputs)
            FR_outputs = self.FR(FR_inputs)

            d = F.sparse_matmul(gs.ML, FL_outputs) + \
                F.sparse_matmul(gs.MH, FH_outputs) + \
                F.sparse_matmul(gs.MR, FR_outputs)

            x_new += d

        return self.FP(x_new)
 def __call__(self, vertex, edge, adj, num_array):
     if self.Wc.array is None:
         v_in_size = vertex.shape[1]
         self._initialize_params_v(v_in_size)
     if self.We.array is None:
         e_in_size = edge.shape[1]
         self._initialize_params_e(e_in_size)
     neighbor = F.matmul(vertex, self.Wn)
     neighbor = F.sparse_matmul(adj, neighbor) / num_array
     center = F.matmul(vertex, self.Wc)
     edge_feature = F.sparse_matmul(edge, self.We)
     length = int(np.sqrt(edge_feature.shape[0]))
     edge_feature = F.reshape(edge_feature,
                              [length, length, edge_feature.shape[1]])
     edge_feature = F.sum(edge_feature, axis=0) / num_array
     output = center + neighbor + edge_feature
     if self.b is not None:
         output += self.b
     return output, edge, adj, num_array
    def __call__(self, vertex, edge, adj, num_array):
        if self.Wc.array is None:
            v_in_size = vertex.shape[1]
            self._initialize_params_v(v_in_size)

        neighbor = F.matmul(vertex, self.Wn)
        neighbor = F.sparse_matmul(adj, neighbor) / num_array
        center = F.matmul(vertex, self.Wc)
        output = center + neighbor
        if self.residual:
            output = vertex + output
        if self.b is not None:
            output += self.b
        return output, edge, adj, num_array
Example #11
    def __call__(self, h, adj, **kwargs):
        """Describing a layer.

        Args:
            h (numpy.ndarray): minibatch by num_nodes by hidden_dim
                numpy array. local node hidden states
            adj (numpy.ndarray): minibatch by num_nodes by num_nodes 1/0 array.
                Adjacency matrices over several bond types

        Returns:
            updated h
        """
        # Support for one graph (node classification task)
        if h.ndim == 2:
            h = h[None]

        # (minibatch, atom, ch)
        mb, atom, ch = h.shape

        # --- Message part ---
        if isinstance(adj, chainer.utils.CooMatrix):
            # coo pattern
            # Support for one graph
            if adj.data.ndim == 1:
                adj.data = adj.data[None]
                adj.col = adj.col[None]
                adj.row = adj.row[None]
            fv = functions.sparse_matmul(adj, h)
        else:
            # padding pattern
            # adj (mb, atom, atom)
            # fv   (minibatch, atom, ch)
            fv = chainer_chemistry.functions.matmul(adj, h)
            assert (fv.shape == (mb, atom, ch))

        # sum myself
        sum_h = fv + h
        assert (sum_h.shape == (mb, atom, ch))

        # apply MLP
        new_h = self.graph_mlp(sum_h)
        new_h = functions.relu(new_h)
        if self.dropout_ratio > 0.0:
            new_h = functions.dropout(new_h, ratio=self.dropout_ratio)
        return new_h
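F.sparse_matmul also accepts batched COO matrices, which is what the `adj.data.ndim == 1` branch above is normalizing towards: chainer.utils.to_coo applied to a 3-D array returns a CooMatrix whose data/row/col carry a leading minibatch axis. A sketch with invented shapes and random values:

import numpy as np
import chainer.functions as F
from chainer import utils

mb, atom, ch = 2, 5, 7
h = np.random.randn(mb, atom, ch).astype(np.float32)

# Batched 0/1 adjacency, one matrix per graph in the minibatch (values made up).
adj_dense = (np.random.rand(mb, atom, atom) > 0.5).astype(np.float32)
adj = utils.to_coo(adj_dense)      # batched CooMatrix: data/row/col shaped (mb, nnz)

fv = F.sparse_matmul(adj, h)       # (mb, atom, ch), same result as a batched dense matmul
assert fv.shape == (mb, atom, ch)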
Example #12
    def __call__(self, x, adj):
        if isinstance(x, chainer.utils.CooMatrix):
            x = copy.deepcopy(x)
            x_data = x.data
            z = []
            for i in range(self.n_heads):
                x.data = F.dropout(x_data, 0.6)
                z.append(F.sparse_matmul(x, self.W[0, i])[:, None, :])
            z = F.concat(z, axis=1)
        else:
            x = F.tile(x[:, None, :], (1, self.n_heads, 1))
            x = F.dropout(x, 0.6)
            z = F.squeeze(F.matmul(x[:, :, None, :], self.W), 2)
        output = self.attention(z, adj)

        if self.b is not None:
            output += self.b

        return output
Example #13
    def __call__(self, h, adj, **kwargs):
        hidden_ch = self.hidden_channels
        # --- Message part ---
        mb, atom, in_ch = h.shape
        m = functions.reshape(self.graph_linear(h),
                              (mb, atom, hidden_ch, self.n_edge_types))
        # m: (minibatch, atom, ch, edge_type)
        # Transpose
        m = functions.transpose(m, (0, 3, 1, 2))
        # m: (minibatch, edge_type, atom, ch)

        # (minibatch * edge_type, atom, out_ch)
        m = functions.reshape(m, (mb * self.n_edge_types, atom, hidden_ch))

        if is_sparse(adj):
            m = functions.sparse_matmul(adj, m)
        else:
            adj = functions.reshape(adj, (mb * self.n_edge_types, atom, atom))
            m = chainer_chemistry.functions.matmul(adj, m)

        # (minibatch * edge_type, atom, out_ch)
        m = functions.reshape(m, (mb, self.n_edge_types, atom, hidden_ch))
        m = functions.sum(m, axis=1)
        # (minibatch, atom, out_ch)

        # --- Update part ---
        # Contraction
        h = functions.reshape(h, (mb * atom, in_ch))

        # Contraction
        m = functions.reshape(m, (mb * atom, hidden_ch))

        out_h = self.update_layer(functions.concat((h, m), axis=1))
        # Expansion
        out_h = functions.reshape(out_h, (mb, atom, self.out_channels))
        return out_h
Example #14
    def __call__(self, batch_graph, targets=None):
        """
        This method performs forward calculation.

        Parameters
        ----------
        batch_graph : list consists of Graph
            contains Graphs in minibatch
        targets : targets
            this parameter is only used in regression task

        Returns
        -------
        In classification task : (batchsize, num_classes) matrix
            which means the probability of which class is each graph in.
        In regression task : (batchsize, 1) matrix
            which means the prediction value of each graph treewidth.
        """
        # set the array module based on the device in use
        xp = self.device.xp

        # concatenate the node_features
        X_concat = chainer.Variable(xp.concatenate([xp.array(graph.node_features) for graph in batch_graph], axis=0))
        X_concat.to_device(self.device)  # when running on GPU, X_concat must be transferred to the device.

        # make graph pooling matrix and neighbors pooling matrix
        graph_pool = self.__preprocess_graphpool(batch_graph)
        if self.neighbor_pooling_type == "max":
            padded_neighbor_list = self.__preprocess_neighbors_maxpool(batch_graph)
        else:
            Adj_block = self.__preprocess_neighbors_sumavepool(batch_graph)

        hidden_rep = [X_concat]  # list of hidden representation at each layer (including input feature vectors)
        h = X_concat

        # perform Aggregating and Combining node features
        for layer in range(self.num_layers-1):
            # perform max neighbor pooling
            if self.neighbor_pooling_type == "max":
                # padding minimum value vector
                padded_h = F.concat((h, F.min(h, axis=0).reshape(1, h.shape[1])), axis=0)

                # make (F-dim, max_deg * nodes) matrix to perform max aggregation
                pooled_mat = F.sparse_matmul(padded_h.transpose(), padded_neighbor_list).transpose()

                # make 3D tensor
                pooled_tensor = F.reshape(pooled_mat, (padded_neighbor_list.shape[0] - 1,
                                          int(padded_neighbor_list.shape[1] / (padded_neighbor_list.shape[0] - 1)), h.shape[1]))

                # take max
                pooled = F.max(pooled_tensor, axis=1)

            # perform sum or average neighbor pooling
            else:
                pooled = F.sparse_matmul(Adj_block, h)
                if self.neighbor_pooling_type == "average":
                    degree = F.sparse_matmul(Adj_block, xp.ones((Adj_block.shape[0], 1), dtype=xp.float32))
                    pooled = pooled/degree

            # input aggregated vectors into MLP
            pooled_rep = self.mlps[layer](pooled)
            h = self.batch_norms[layer](pooled_rep)
            h = F.relu(h)
            hidden_rep.append(h)

        # perform Readout node features
        score_over_layer = 0
        for layer, h in enumerate(hidden_rep):
            # perform max readout
            if self.graph_pooling_type == "max":
                # padding minimum value
                padded_h = F.concat((h, F.min(h, axis=0).reshape(1, h.shape[1])), axis=0)

                # make (F-dim, max|V| * batchsize) matrix to perform max aggregation
                pooled_mat = F.sparse_matmul(padded_h.transpose(), graph_pool).transpose()

                # make 3D tensor
                pooled_tensor = F.reshape(pooled_mat, (len(batch_graph), int(graph_pool.shape[1] / len(batch_graph)), h.shape[1]))

                # take max
                pooled_h = F.max(pooled_tensor, axis=1)

            # sum or average readout
            else:
                pooled_h = F.sparse_matmul(graph_pool, h)

            score_over_layer += F.dropout(self.linears_prediction[layer](pooled_h), self.final_dropout)

        # final layers in regression task
        if self.task_type == "Regression":
            h = self.final_l2(score_over_layer)
            h = F.relu(h)
            score_over_layer = self.final_l1(h)

            if targets is None:
                return score_over_layer
            else:
                self.loss = F.mean_squared_error(targets.reshape(-1, 1), score_over_layer)  # MSE Loss
                self.abs_loss = F.mean_absolute_error(targets.reshape(-1, 1), score_over_layer)  # MAE Loss
                self.abs_max_loss = F.max(F.absolute_error(targets.reshape(-1, 1), score_over_layer))  # Max Absolute Error
                chainer.reporter.report({'loss': self.loss}, self)
                chainer.reporter.report({'abs_loss': self.abs_loss}, self)
                chainer.reporter.report({'abs_max_loss': self.abs_max_loss}, self)
                # return the MSE loss. To use a different loss, change this line.
                return self.loss

        return score_over_layer
    def update_core(self):
        optimizer_sd = self.get_optimizer('main')
        optimizer_enc = self.get_optimizer('enc')
        optimizer_dec = self.get_optimizer('dec')
        optimizer_dis = self.get_optimizer('dis')
        xp = self.seed.xp

        step = self.iteration % self.args.iter
        osem_step = step % self.args.osem
        if step == 0:
            batch = self.get_iterator('main').next()
            self.prImg, self.rev, self.patient_id, self.slice = self.converter(batch, self.device)
            print(self.prImg.shape)
            self.n_reconst += 1
            self.recon_freq = 1
            if ".npy" in self.args.model_image:
                self.seed.W.array = xp.reshape(xp.load(self.args.model_image),(1,1,self.args.crop_height,self.args.crop_width))
            elif ".dcm" in self.args.model_image:
                ref_dicom = dicom.read_file(self.args.model_image, force=True)
                img = xp.array(ref_dicom.pixel_array+ref_dicom.RescaleIntercept)
                img = (2*(xp.clip(img,self.args.HU_base,self.args.HU_base+self.args.HU_range)-self.args.HU_base)/self.args.HU_range-1.0).astype(np.float32)
                self.seed.W.array = xp.reshape(img,(1,1,self.args.crop_height,self.args.crop_width))
            else:
#                initializers.Uniform(scale=0.5)(self.seed.W.array)
                initializers.HeNormal()(self.seed.W.array)
            self.initial_seed = self.seed.W.array.copy()
#            print(xp.min(self.initial_seed),xp.max(self.initial_seed),xp.mean(self.initial_seed))

        ## for seed array
        arr = self.seed()
        HU = self.var2HU(arr)
        raw = self.HU2raw(HU)

        self.seed.cleargrads()
        loss_seed = Variable(xp.array([0.0],dtype=np.float32))
        # conjugate correction using system matrix
        if self.args.lambda_sd > 0:
            self.seed.W.grad = xp.zeros_like(self.seed.W.array)
            loss_sd = 0
            for i in range(len(self.prImg)):
                if self.rev[i]:
                    rec_sd = F.exp(-F.sparse_matmul(self.prMats[osem_step],F.reshape(raw[i,:,::-1,::-1],(-1,1)))) ##
                else:
                    rec_sd = F.exp(-F.sparse_matmul(self.prMats[osem_step],F.reshape(raw[i],(-1,1)))) ##
                if self.args.log:
                    loss_sd += F.mean_squared_error(F.log(rec_sd),F.log(self.prImg[i][osem_step]))
                else:
                    loss_sd += F.mean_squared_error(rec_sd,self.prImg[i][osem_step])
                if self.args.system_matrix:
                    gd = F.sparse_matmul( self.conjMats[osem_step], rec_sd-self.prImg[i][osem_step], transa=True)
                    if self.rev[i]:
                        self.seed.W.grad[i] -= self.args.lambda_sd * F.reshape(gd, (1,self.args.crop_height,self.args.crop_width)).array[:,::-1,::-1]    # / logrep.shape[0] ?
                    else:
                        self.seed.W.grad[i] -= self.args.lambda_sd * F.reshape(gd, (1,self.args.crop_height,self.args.crop_width)).array    # / logrep.shape[0] ?

            if not self.args.system_matrix:
                (self.args.lambda_sd *loss_sd).backward()
            chainer.report({'loss_sd': loss_sd/len(self.prImg)}, self.seed)

        if self.args.lambda_tvs > 0:
            loss_tvs = losses.total_variation(arr, tau=self.args.tv_tau, method=self.args.tv_method)
            loss_seed += self.args.lambda_tvs * loss_tvs
            chainer.report({'loss_tvs': loss_tvs}, self.seed)

        if self.args.lambda_advs>0:
            L_advs = F.average( (self.dis(arr)-1.0)**2 )
            loss_seed += self.args.lambda_advs * L_advs
            chainer.report({'loss_advs': L_advs}, self.seed)

        ## generator output
        arr_n = losses.add_noise(arr,self.args.noise_gen)
        if self.args.no_train_seed:
            arr_n.unchain()
        if not self.args.decoder_only:
            arr_n = self.encoder(arr_n)
        gen = self.decoder(arr_n) # range = [-1,1]

        ## generator loss
        loss_gen = Variable(xp.array([0.0],dtype=np.float32))
        plan, plan_ae = None, None
        if self.args.lambda_ae1>0 or self.args.lambda_ae2>0:
            plan = losses.add_noise(Variable(self.converter(self.get_iterator('planct').next(), self.device)), self.args.noise_dis)
            plan_enc = self.encoder(plan)
            plan_ae = self.decoder(plan_enc)
            loss_ae1 = F.mean_absolute_error(plan,plan_ae)
            loss_ae2 = F.mean_squared_error(plan,plan_ae)
            if self.args.lambda_reg>0:
                loss_reg_ae = losses.loss_func_reg(plan_enc[-1],'l2')
                chainer.report({'loss_reg_ae': loss_reg_ae}, self.seed)
                loss_gen += self.args.lambda_reg * loss_reg_ae
            loss_gen += self.args.lambda_ae1 * loss_ae1 + self.args.lambda_ae2 * loss_ae2
            chainer.report({'loss_ae1': loss_ae1}, self.seed)
            chainer.report({'loss_ae2': loss_ae2}, self.seed)
        if self.args.lambda_tv > 0:
            L_tv = losses.total_variation(gen, tau=self.args.tv_tau, method=self.args.tv_method)
            loss_gen += self.args.lambda_tv * L_tv
            chainer.report({'loss_tv': L_tv}, self.seed)
        if self.args.lambda_adv>0:
            L_adv = F.average( (self.dis(gen)-1.0)**2 )
            loss_gen += self.args.lambda_adv * L_adv
            chainer.report({'loss_adv': L_adv}, self.seed)
        ## regularisation on the latent space
        if self.args.lambda_reg>0:
            loss_reg = losses.loss_func_reg(arr_n[-1],'l2')
            chainer.report({'loss_reg': loss_reg}, self.seed)
            loss_gen += self.args.lambda_reg * loss_reg

        self.encoder.cleargrads()
        self.decoder.cleargrads()
        loss_gen.backward()
        loss_seed.backward()
        chainer.report({'loss_gen': loss_gen}, self.seed)
        optimizer_enc.update()
        optimizer_dec.update()
        optimizer_sd.update()

        chainer.report({'grad_sd': F.average(F.absolute(self.seed.W.grad))}, self.seed)
        if hasattr(self.decoder, 'latent_fc'):
            chainer.report({'grad_gen': F.average(F.absolute(self.decoder.latent_fc.W.grad))}, self.seed)

        # reconstruction consistency for NN
        if (step % self.recon_freq == 0) and self.args.lambda_nn>0:
            self.encoder.cleargrads()
            self.decoder.cleargrads()
            self.seed.cleargrads()
            gen.grad = xp.zeros_like(gen.array)

            HU_nn = self.var2HU(gen)
            raw_nn = self.HU2raw(HU_nn)
            loss_nn = 0
            for i in range(len(self.prImg)):
                if self.rev[i]:
                    rec_nn = F.exp(-F.sparse_matmul(self.prMats[osem_step],F.reshape(raw_nn[i,:,::-1,::-1],(-1,1))))
                else:
                    rec_nn = F.exp(-F.sparse_matmul(self.prMats[osem_step],F.reshape(raw_nn[i],(-1,1))))
                loss_nn += F.mean_squared_error(rec_nn,self.prImg[i][osem_step])
                if self.args.system_matrix:
                    gd_nn = F.sparse_matmul( rec_nn-self.prImg[i][osem_step], self.conjMats[osem_step], transa=True )
                    if self.rev[i]:
                        gen.grad[i] -= self.args.lambda_nn * F.reshape(gd_nn, (1,self.args.crop_height,self.args.crop_width)).array[:,::-1,::-1]
                    else:
                        gen.grad[i] -= self.args.lambda_nn * F.reshape(gd_nn, (1,self.args.crop_height,self.args.crop_width)).array
            chainer.report({'loss_nn': loss_nn/len(self.prImg)}, self.seed)
            if self.args.system_matrix:
                gen.backward()
            else:
                (self.args.lambda_nn * loss_nn).backward()

            if not self.args.no_train_seed:
                optimizer_sd.update()
            if not self.args.no_train_enc:
                optimizer_enc.update()
            if not self.args.no_train_dec:
                optimizer_dec.update()

            if self.seed.W.grad is not None:
                chainer.report({'grad_sd_consistency': F.average(F.absolute(self.seed.W.grad))}, self.seed)
            if hasattr(self.decoder, 'latent_fc'):
                chainer.report({'grad_gen_consistency': F.average(F.absolute(self.decoder.latent_fc.W.grad))}, self.seed)
            elif hasattr(self.decoder, 'ul'):
                chainer.report({'grad_gen_consistency': F.average(F.absolute(self.decoder.ul.c1.c.W.grad))}, self.seed)

        chainer.report({'seed_diff': F.mean_absolute_error(self.initial_seed,self.seed.W)/F.mean_absolute_error(self.initial_seed,xp.zeros_like(self.initial_seed))}, self.seed)

        # clip seed to [-1,1]
        if self.args.clip:
            self.seed.W.array = xp.clip(self.seed.W.array,a_min=-1.0, a_max=1.0)

        # adjust consistency loss update frequency
        self.recon_freq = max(1,int(round(self.args.max_reconst_freq * (step-self.args.reconst_freq_decay_start) / (self.args.iter+1-self.args.reconst_freq_decay_start))))

        ## for discriminator
        fake = None
        if self.args.dis_freq > 0 and ( (step+1) % self.args.dis_freq == 0) and (self.args.lambda_gan+self.args.lambda_adv+self.args.lambda_advs>0):
            # get mini-batch
            if plan is None:
                plan = self.converter(self.get_iterator('planct').next(), self.device)
                plan = losses.add_noise(Variable(plan),self.args.noise_dis)
            
            # create fake
            if self.args.lambda_gan>0:
                if self.args.decoder_only:
                    fake_seed = xp.random.uniform(-1,1,(1,self.args.latent_dim)).astype(np.float32)
                else:
                    fake_seed = self.encoder(xp.random.uniform(-1,1,(1,1,self.args.crop_height,self.args.crop_width)).astype(np.float32))
                fake = self.decoder(fake_seed)
                # decoder
                self.decoder.cleargrads()
                loss_gan = F.average( (self.dis(fake)-1.0)**2 )
                chainer.report({'loss_gan': loss_gan}, self.seed)
                loss_gan *= self.args.lambda_gan
                loss_gan.backward()
                optimizer_dec.update(loss=loss_gan)
                fake_copy = self._buffer.query(fake.array)
            if self.args.lambda_nn>0:
                fake_copy = self._buffer.query(self.converter(self.get_iterator('mvct').next(), self.device))
            if (step+1) % (self.args.iter // 30):
                fake_copy = Variable(self._buffer.query(gen.array))
            # discriminator
            L_real = F.average( (self.dis(plan)-1.0)**2 )
            L_fake = F.average( self.dis(fake_copy)**2 )
            loss_dis = 0.5*(L_real+L_fake)
            self.dis.cleargrads()
            loss_dis.backward()
            optimizer_dis.update()
            chainer.report({'loss_dis': (L_real+L_fake)/2}, self.seed)


        if ((self.iteration+1) % self.args.vis_freq == 0) or  ((step+1)==self.args.iter):
            for i in range(self.args.batchsize):
                outlist=[]
                if not self.args.no_train_seed and not self.args.decoder_only:
                    outlist.append((self.seed()[i],"0sd"))
                if plan_ae is not None:
                    outlist.append((plan[i],'2pl'))
                    outlist.append((plan_ae[i],'3ae'))
                if self.args.lambda_nn>0 or self.args.lambda_adv>0:
                    if self.args.decoder_only:
                        gen_img = self.decoder([self.seed()])
                    else:
                        gen_img = self.decoder(self.encoder(self.seed()))
                    outlist.append((gen_img[i],'1gn'))
                if fake is not None:
                    outlist.append((fake[i],'4fa'))
                for out,typ in outlist:
                    out.to_cpu()
                    HU = (((out+1)/2 * self.args.HU_range)+self.args.HU_base).array  # [-1000=air,0=water,>1000=bone]
                    print("type: ",typ,"HU:",np.min(HU),np.mean(HU),np.max(HU))
                    #visimg = np.clip((out.array+1)/2,0,1) * 255.0
                    b,r = -self.args.HU_range_vis//2,self.args.HU_range_vis
                    visimg = (np.clip(HU,b,b+r)-b)/r * 255.0
                    fn = 'n{:0>5}_iter{:0>6}_p{}_z{}_{}'.format(self.n_reconst,step+1,self.patient_id[i],self.slice[i],typ)
                    write_image(np.uint8(visimg),os.path.join(self.args.out,fn+'.jpg'))
                    if (step+1)==self.args.iter or (not self.args.no_save_dcm):
                        #np.save(os.path.join(self.args.out,fn+'.npy'),HU[0])
                        write_dicom(os.path.join(self.args.out,fn+'.dcm'),HU[0])
Example #16
 def check_SPDN_forward(self, a_data, b_data, atol=1e-4, rtol=1e-5):
     sp_a = utils.to_coo(a_data, requires_grad=True)
     b = chainer.Variable(b_data)
     c = F.sparse_matmul(sp_a, b, transa=self.transa, transb=self.transb)
     testing.assert_allclose(self.forward_answer, c.data, atol, rtol)
Example #17
 def check_DNSP_forward(self, a_data, b_data, atol=1e-4, rtol=1e-5):
     a = chainer.Variable(a_data)
     sp_b = utils.to_coo(b_data, requires_grad=True)
     c = F.sparse_matmul(a, sp_b, transa=self.transa, transb=self.transb)
     testing.assert_allclose(self.forward_answer, c.data, atol, rtol)
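For completeness, a small sketch of the two argument orders exercised by these checks (sparse x dense and dense x sparse), on invented data, verified against NumPy:

import numpy as np
import chainer.functions as F
from chainer import utils

a = np.random.randn(3, 4).astype(np.float32)
b = np.random.randn(4, 5).astype(np.float32)
sp_a = utils.to_coo(a)
sp_b = utils.to_coo(b)

c1 = F.sparse_matmul(sp_a, b)   # SPDN: sparse x dense
c2 = F.sparse_matmul(a, sp_b)   # DNSP: dense x sparse
np.testing.assert_allclose(c1.data, a @ b, atol=1e-4, rtol=1e-5)
np.testing.assert_allclose(c2.data, a @ b, atol=1e-4, rtol=1e-5)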