예제 #1
0
def pretrain_dsm(model, t_train, e_train, t_valid, e_valid,
                 n_iter=10000, lr=1e-2, thres=1e-4):
  """Pretrain a small DSM model on the survival data alone.

  A fresh ``DeepSurvivalMachinesTorch(1, 1, dist=model.dist)`` is built in
  double precision and fit with Adam to ``unconditional_loss`` on the
  training data. After every optimisation step the validation loss is
  evaluated. Whenever it changes by less than ``thres`` the patience counter
  is incremented, and training stops once the counter reaches 3. The counter
  is never reset, so the three small changes need not be consecutive.

  Args:
    model: Model whose ``dist`` attribute selects the distribution of the
      pretrained model. No other attribute of ``model`` is read here.
    t_train, e_train: Training data forwarded to ``unconditional_loss``;
      presumably event/censoring times and event indicators -- confirm
      against ``unconditional_loss``.
    t_valid, e_valid: Validation data in the same format.
    n_iter: Maximum number of optimisation steps.
    lr: Adam learning rate.
    thres: Absolute change in validation loss below which an iteration
      counts toward early stopping.

  Returns:
    The pretrained ``DeepSurvivalMachinesTorch`` instance (double precision).
  """
  # NOTE(review): the positional (1, 1) constructor args presumably mean a
  # 1-dim input and a single component -- confirm against the class.
  premodel = DeepSurvivalMachinesTorch(1, 1,
                                       dist=model.dist)
  premodel.double()

  optimizer = torch.optim.Adam(premodel.parameters(), lr=lr)
  oldcost = float('inf')  # guarantees the first comparison never counts
  patience = 0

  costs = []
  for _ in tqdm(range(n_iter)):

    optimizer.zero_grad()

    loss = unconditional_loss(premodel, t_train, e_train)
    loss.backward()
    optimizer.step()

    # The validation loss is only monitored, never back-propagated, so do
    # not build an autograd graph for it.
    with torch.no_grad():
      valid_loss = unconditional_loss(premodel, t_valid, e_valid)
    valid_loss = valid_loss.detach().cpu().numpy()

    costs.append(valid_loss)

    # Early stopping: patience is cumulative (not reset on a large change).
    if np.abs(costs[-1] - oldcost) < thres:
      patience += 1
      if patience == 3:
        break
    oldcost = costs[-1]

  return premodel
예제 #2
0
def pretrain_dsm(model,
                 t_train,
                 e_train,
                 t_valid,
                 e_valid,
                 n_iter=10000,
                 lr=1e-2,
                 thres=1e-4,
                 cuda=False):
    """Pretrain a small multi-risk DSM model on the survival data alone.

    A fresh ``DeepSurvivalMachinesTorch(1, 1, ...)`` is built in double
    precision. It shares ``dist``, ``risks`` and ``optimizer`` with *model*
    and is trained on the sum over risks of ``unconditional_loss``. After
    every optimisation step the summed validation loss is evaluated. Whenever
    it changes by less than ``thres`` the patience counter is incremented, and
    training stops once the counter reaches 3. The counter is never reset, so
    the three small changes need not be consecutive.

    Args:
        model: Model providing ``dist``, ``risks`` and ``optimizer``;
            ``risks`` also sets how many per-risk losses are summed.
        t_train, e_train: Training data forwarded to ``unconditional_loss``;
            presumably event/censoring times and event indicators (tensors,
            since ``.cuda()`` is called on them) -- confirm.
        t_valid, e_valid: Validation data in the same format.
        n_iter: Maximum number of optimisation steps.
        lr: Learning rate passed to ``get_optimizer``.
        thres: Absolute change in validation loss below which an iteration
            counts toward early stopping.
        cuda: If true, move the pretrained model and all data to the GPU.

    Returns:
        The pretrained ``DeepSurvivalMachinesTorch`` instance.
    """
    # Function-scope import: this block needs torch.no_grad but does not
    # otherwise reference the ``torch`` name.
    import torch

    premodel = DeepSurvivalMachinesTorch(1,
                                         1,
                                         dist=model.dist,
                                         risks=model.risks,
                                         optimizer=model.optimizer).double()

    if cuda:
        premodel.cuda()
        t_train, e_train = t_train.cuda(), e_train.cuda()
        t_valid, e_valid = t_valid.cuda(), e_valid.cuda()

    # Created after the optional .cuda() move, as PyTorch recommends for
    # optimizers.
    optimizer = get_optimizer(premodel, lr)

    def _total_loss(t, e):
        # Sum unconditional_loss over all risks, whose keys are '1'..'risks'.
        total = 0
        for r in range(model.risks):
            total += unconditional_loss(premodel, t, e, str(r + 1))
        return total

    oldcost = float('inf')  # guarantees the first comparison never counts
    patience = 0
    costs = []
    for _ in tqdm(range(n_iter)):

        optimizer.zero_grad()
        loss = _total_loss(t_train, e_train)
        loss.backward()
        optimizer.step()

        # Validation loss is only monitored, never back-propagated, so skip
        # building an autograd graph for it.
        with torch.no_grad():
            valid_loss = _total_loss(t_valid, e_valid)
        valid_loss = valid_loss.detach().cpu().numpy()
        costs.append(valid_loss)

        # Early stopping: patience is cumulative (not reset on a large change).
        if np.abs(costs[-1] - oldcost) < thres:
            patience += 1
            if patience == 3:
                break
        oldcost = costs[-1]

    return premodel