Example #1
def eval_omics_net(x_tr, age_tr, y_tr, delta_tr,
                   x_va, age_va, y_va, delta_va,
                   x_te, age_te, y_te, delta_te,
                   gene_indices, pathway_indices,
                   in_nodes, gene_nodes, pathway_nodes, hidden_nodes,
                   LR, L2, max_epochs, dropout_rate,
                   step=100, tolerance=0.02, sparse_coding=False):

    net = omics_net(in_nodes, gene_nodes, pathway_nodes, hidden_nodes)
    ###if gpu is being used
    if torch.cuda.is_available():
        net = net.cuda()
    ###optimizer
    opt = optim.Adam(net.parameters(), lr=LR, weight_decay=L2)

    ###best (train + valid) c-index sum seen so far, used for early stopping
    prev_sum = 0.0
    for epoch in range(max_epochs):
        net.train()
        ###reset gradients to zeros
        opt.zero_grad() 
        ###Randomize dropout masks
        net.do_m1 = dropout_mask(pathway_nodes, dropout_rate[0])
        net.do_m2 = dropout_mask(hidden_nodes[0], dropout_rate[1])
        ###Forward
        pred = net(x_tr, age_tr, gene_indices, pathway_indices, dropout_rate)
        ###calculate loss
        loss = neg_par_log_likelihood(pred, y_tr, delta_tr)
        ###calculate gradients
        loss.backward() 
        ###force the connections between omics layer and gene layer w.r.t. 'gene_mask'
        net.omics.weight.grad = fixed_s_mask(net.omics.weight.grad, gene_indices)
        ###force the connections between gene layer and pathway layer w.r.t. 'pathway_mask'
        net.gene.weight.grad = fixed_s_mask(net.gene.weight.grad, pathway_indices)
        ###update weights and biases
        opt.step()
        if sparse_coding:
            net = sparse_func(net, x_tr, age_tr, y_tr, delta_tr, gene_indices, pathway_indices, dropout_rate)
        if epoch % step == step - 1:
            net.train()
            pred = net(x_tr, age_tr, gene_indices, pathway_indices, dropout_rate)
            train_cindex = c_index(pred.cpu(), y_tr.cpu(), delta_tr.cpu())
            net.eval()
            pred = net(x_va, age_va, gene_indices, pathway_indices, dropout_rate)
            eval_cindex = c_index(pred.cpu(), y_va.cpu(), delta_va.cpu())
            ###stop when the train + valid c-index sum drops more than `tolerance` below the best seen
            if (eval_cindex.item() + train_cindex.item() + tolerance) < prev_sum:
                print('Early stopping in [%d]' % (epoch + 1))
                print('[%d] Best CIndex in Train: %.3f' % (epoch + 1, opt_cidx_tr))
                print('[%d] Best CIndex in Valid: %.3f' % (epoch + 1, opt_cidx_va))
                opt_net.eval()
                pred = opt_net(x_te, age_te, gene_indices, pathway_indices, dropout_rate)
                test_cindex = c_index(pred.cpu(), y_te.cpu(), delta_te.cpu())
                break
            else:
                prev_sum = eval_cindex.item() + train_cindex.item()
                opt_cidx_tr = train_cindex
                opt_cidx_va = eval_cindex
                opt_net = copy.deepcopy(net)
                print('[%d] CIndex in Train: %.3f' % (epoch + 1, train_cindex))
                print('[%d] CIndex in Valid: %.3f' % (epoch + 1, eval_cindex))
    else:
        ###no early stop triggered within max_epochs: evaluate the best network on the test set
        opt_net.eval()
        pred = opt_net(x_te, age_te, gene_indices, pathway_indices, dropout_rate)
        test_cindex = c_index(pred.cpu(), y_te.cpu(), delta_te.cpu())

    return (opt_cidx_tr, opt_cidx_va, test_cindex)
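
The helpers `dropout_mask` and `fixed_s_mask` used above are not shown. A minimal sketch of plausible implementations, assuming `dropout_mask` draws an inverted-dropout Bernoulli keep-mask and `fixed_s_mask` simply multiplies the gradient by a 0/1 connection mask of the same shape (both are assumptions about the missing code, not the original implementations):

import torch

def dropout_mask(n_nodes, drop_rate):
    ###Bernoulli keep-mask over n_nodes units (assumed semantics)
    keep_prob = 1.0 - drop_rate
    mask = torch.bernoulli(torch.full((n_nodes,), keep_prob))
    ###inverted dropout: rescale kept units so expected activations are unchanged
    return mask / keep_prob

def fixed_s_mask(grad, connection_mask):
    ###zero gradient entries outside the allowed sparse connections;
    ###`connection_mask` is assumed to be a 0/1 tensor shaped like `grad`
    return grad * connection_mask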
Example #2
def sparse_func(net, x_tr, age_tr, y_tr, delta_tr, Gene_Indices,
                Pathway_Indices, Dropout_Rate):
    '''Sparse coding phase: optimize the connections between intermediate layers sequentially'''
    ###serializing net
    net_state_dict = net.state_dict()
    ###make a copy for net, and then optimize sparsity level via copied net
    copy_net = copy.deepcopy(net)
    copy_state_dict = copy_net.state_dict()
    for name, param in net_state_dict.items():
        ###only the pathway and hidden layers are sparsified; skip biases,
        ###batch-norm parameters, and the masked omics/gene layers
        if "weight" not in name: continue
        if "omics" in name: continue
        if "gene" in name: continue
        if "hidden2" in name: continue
        if "bn" in name: continue
        if "pathway" in name:
            active_mask = small_net_mask(net.pathway.weight.data, net.do_m1,
                                         net.do_m2)
            copy_weight = copy.deepcopy(net.pathway.weight.data)
        if "hidden" in name:
            active_mask = small_net_mask(net.hidden.weight.data, net.do_m2,
                                         net.do_m3)
            copy_weight = copy.deepcopy(net.hidden.weight.data)
        ###candidate sparsity levels, in percent of active weights to zero out
        S_set = torch.linspace(99, 0, 5)
        S_loss = []
        for S in S_set:
            sp_param = get_sparse_weight(copy_weight, active_mask, S.item())
            copy_state_dict[name].copy_(sp_param)
            copy_net.train()
            y_tmp = copy_net(x_tr, age_tr, Gene_Indices, Pathway_Indices,
                             Dropout_Rate)
            loss_tmp = neg_par_log_likelihood(y_tmp, y_tr, delta_tr)
            S_loss.append(loss_tmp)
        ###apply cubic interpolation
        best_S = get_best_sparsity(S_set, S_loss)
        best_epsilon = get_threshold(copy_weight, active_mask, best_S)
        optimal_sp_param = soft_threshold(copy_weight, best_epsilon)
        copy_weight[active_mask] = optimal_sp_param[active_mask]
        ###update weights in copied net
        copy_state_dict[name].copy_(copy_weight)
        ###update weights in net
        net_state_dict[name].copy_(copy_weight)
    return net
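
Example 2 delegates the sparsity search to `get_best_sparsity`, `get_threshold` and `soft_threshold`, which are also not shown. A minimal sketch under plausible assumptions: `get_best_sparsity` cubic-interpolates the loss over the candidate levels with scipy's `interp1d`, `get_threshold` takes the matching magnitude percentile over the active weights, and `soft_threshold` is the standard shrinkage operator (all assumed, not the original code):

import numpy as np
import torch
from scipy.interpolate import interp1d

def get_best_sparsity(S_set, S_loss):
    ###pick the sparsity level minimizing a cubic interpolation of the loss curve
    s = np.asarray(S_set)
    losses = np.array([l.item() for l in S_loss])
    curve = interp1d(s, losses, kind='cubic')
    s_fine = np.linspace(s.min(), s.max(), 100)
    return float(s_fine[np.argmin(curve(s_fine))])

def get_threshold(weight, active_mask, sparsity):
    ###magnitude below which `sparsity` percent of the active weights fall
    return torch.quantile(weight[active_mask].abs(), sparsity / 100.0)

def soft_threshold(weight, epsilon):
    ###standard soft-thresholding: shrink magnitudes by epsilon, clipping at zero
    return torch.sign(weight) * torch.clamp(weight.abs() - epsilon, min=0.0)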
Example #3
def trainCoxPASNet(train_x, train_age, train_ytime, train_yevent,
                   eval_x, eval_age, eval_ytime, eval_yevent, pathway_mask,
                   In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes,
                   Learning_Rate, L2, Num_Epochs, Dropout_Rate):

    net = Cox_PASNet(In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes,
                     pathway_mask)
    ###if gpu is being used
    if torch.cuda.is_available():
        net.cuda()
    ###
    ###optimizer
    opt = optim.Adam(net.parameters(), lr=Learning_Rate, weight_decay=L2)

    for epoch in range(Num_Epochs + 1):
        net.train()
        opt.zero_grad()  ###reset gradients to zeros
        ###Randomize dropout masks
        net.do_m1 = dropout_mask(Pathway_Nodes, Dropout_Rate[0])
        net.do_m2 = dropout_mask(Hidden_Nodes, Dropout_Rate[1])

        pred = net(train_x, train_age)  ###Forward
        loss = neg_par_log_likelihood(pred, train_ytime, train_yevent)  ###calculate loss
        loss.backward()  ###calculate gradients
        opt.step()  ###update weights and biases

        ###force the connections between gene layer and pathway layer
        net.sc1.weight.data = net.sc1.weight.data.mul(net.pathway_mask)

        ###obtain the small sub-network's connections
        do_m1_grad = copy.deepcopy(net.sc2.weight.grad.data)
        do_m2_grad = copy.deepcopy(net.sc3.weight.grad.data)
        do_m1_grad_mask = torch.where(do_m1_grad == 0, do_m1_grad,
                                      torch.ones_like(do_m1_grad))
        do_m2_grad_mask = torch.where(do_m2_grad == 0, do_m2_grad,
                                      torch.ones_like(do_m2_grad))
        ###copy the weights
        net_sc2_weight = copy.deepcopy(net.sc2.weight.data)
        net_sc3_weight = copy.deepcopy(net.sc3.weight.data)

        ###serializing net
        net_state_dict = net.state_dict()

        ###Sparse Coding
        ###make a copy for net, and then optimize sparsity level via copied net
        copy_net = copy.deepcopy(net)
        copy_state_dict = copy_net.state_dict()
        for name, param in copy_state_dict.items():
            ###omit the param if it is not a weight matrix
            if not "weight" in name:
                continue
            ###omit gene layer
            if "sc1" in name:
                continue
            ###stop sparse coding
            if "sc4" in name:
                break
            ###sparse coding between the current two consecutive layers is in the trained small sub-network
            if "sc2" in name:
                active_param = net_sc2_weight.mul(do_m1_grad_mask)
            if "sc3" in name:
                active_param = net_sc3_weight.mul(do_m2_grad_mask)
            nonzero_param_1d = active_param[active_param != 0]
            ###stop sparse coding between these two layers if no valid weights remain
            if nonzero_param_1d.size(0) == 0:
                break
            copy_param_1d = copy.deepcopy(nonzero_param_1d)
            ###set up potential sparsity level in [0, 100)
            S_set = torch.arange(100, -1, -1)[1:]
            copy_param = copy.deepcopy(active_param)
            S_loss = []
            for S in S_set:
                param_mask = s_mask(sparse_level=S.item(),
                                    param_matrix=copy_param,
                                    nonzero_param_1D=copy_param_1d,
                                    dtype=dtype)  ###`dtype` assumed defined at module level
                transformed_param = copy_param.mul(param_mask)
                copy_state_dict[name].copy_(transformed_param)
                copy_net.train()
                y_tmp = copy_net(train_x, train_age)
                loss_tmp = neg_par_log_likelihood(y_tmp, train_ytime, train_yevent)
                S_loss.append(loss_tmp.item())  ###store as a plain float for interpolation
            ###apply cubic interpolation
            interp_S_loss = interp1d(S_set, S_loss, kind='cubic')
            interp_S_set = torch.linspace(S_set.min().item(), S_set.max().item(), steps=100)
            interp_loss = interp_S_loss(interp_S_set)
            optimal_S = interp_S_set[np.argmin(interp_loss)]
            optimal_param_mask = s_mask(sparse_level=optimal_S.item(),
                                        param_matrix=copy_param,
                                        nonzero_param_1D=copy_param_1d,
                                        dtype=dtype)
            if "sc2" in name:
                final_optimal_param_mask = torch.where(
                    do_m1_grad_mask == 0, torch.ones_like(do_m1_grad_mask),
                    optimal_param_mask)
                optimal_transformed_param = net_sc2_weight.mul(
                    final_optimal_param_mask)
            if "sc3" in name:
                final_optimal_param_mask = torch.where(
                    do_m2_grad_mask == 0, torch.ones_like(do_m2_grad_mask),
                    optimal_param_mask)
                optimal_transformed_param = net_sc3_weight.mul(
                    final_optimal_param_mask)
            ###update weights in copied net
            copy_state_dict[name].copy_(optimal_transformed_param)
            ###update weights in net
            net_state_dict[name].copy_(optimal_transformed_param)

        if epoch % 200 == 0:
            net.train()
            train_pred = net(train_x, train_age)
            train_loss = neg_par_log_likelihood(train_pred, train_ytime,
                                                train_yevent).view(1, )

            net.eval()
            eval_pred = net(eval_x, eval_age)
            eval_loss = neg_par_log_likelihood(eval_pred, eval_ytime,
                                               eval_yevent).view(1, )

            train_cindex = c_index(train_pred, train_ytime, train_yevent)
            eval_cindex = c_index(eval_pred, eval_ytime, eval_yevent)
            print("Loss in Train: ", train_loss)

    return (train_loss, eval_loss, train_cindex, eval_cindex)
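
All three trainers score models with `c_index` and minimize `neg_par_log_likelihood`; neither is shown. A minimal sketch of both under the usual survival-analysis conventions, assuming `pred` holds linear risk scores, `ytime` observed times, and `yevent` event indicators (1 = event, 0 = censored); the originals may differ, e.g. in normalization or tie handling:

import torch

def neg_par_log_likelihood(pred, ytime, yevent):
    ###negative Cox partial log-likelihood (sketch):
    ###  loss = -sum_i yevent_i * (pred_i - log sum_{j: t_j >= t_i} exp(pred_j))
    pred, ytime, yevent = pred.view(-1), ytime.view(-1), yevent.view(-1)
    ###risk_set[i, j] = 1 if sample j is still at risk at sample i's event time
    risk_set = (ytime.view(-1, 1) <= ytime.view(1, -1)).float()
    log_risk = torch.log((risk_set * torch.exp(pred).view(1, -1)).sum(dim=1))
    n_events = torch.clamp(yevent.sum(), min=1.0)
    return -((pred - log_risk) * yevent).sum() / n_events

def c_index(pred, ytime, yevent):
    ###Harrell's concordance index over comparable pairs (sketch)
    pred, ytime, yevent = pred.view(-1), ytime.view(-1), yevent.view(-1)
    concordant, comparable = 0.0, 0.0
    for i in range(ytime.size(0)):
        if yevent[i] == 0:
            continue  ###a censored sample cannot anchor a comparable pair
        for j in range(ytime.size(0)):
            if ytime[j] > ytime[i]:  ###j outlived i, so i should look riskier
                comparable += 1.0
                if pred[i] > pred[j]:
                    concordant += 1.0
                elif pred[i] == pred[j]:
                    concordant += 0.5
    return torch.tensor(concordant / max(comparable, 1.0))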