class Model(NetInterface):
    """Per-modality autoencoders whose latent codes are concatenated for classification.

    Each modality ``i`` has its own encoder ``net_q[i]`` (Q_net) and decoder
    ``net_p[i]`` (P_net). The encoder outputs of all modalities are concatenated
    along dim 1 and fed to a single classifier ``net_c`` (C_net).

    Each training batch runs two stages:
      1. update every Q/P pair on its own MSE reconstruction loss;
      2. update all Q networks and C on the cross-entropy classification loss.
    """

    @classmethod
    def add_arguments(cls, parser):
        """Register this model's command-line options on ``parser``.

        Returns:
            ``(parser, set())``.
        """
        parser.add_argument('--net_N',
                            type=int,
                            default=128,
                            help="Number of neurons in hidden layers")
        # With nargs='+' a value given on the command line is a list. The
        # default must be a list as well, because __init__ calls len() on it and
        # indexes it per modality (a scalar default made len() raise TypeError).
        parser.add_argument(
            '--x_dim',
            type=int,
            nargs='+',
            default=[1000],
            help=
            "dimension of input features, take one or more arguments for integration"
        )
        parser.add_argument('--z_dim',
                            type=int,
                            default=100,
                            help="dimension of hidden variables")
        parser.add_argument('--n_classes',
                            type=int,
                            default=1,
                            help="number of nodes for classification")
        parser.add_argument('--p_drop',
                            type=float,
                            default=0.2,
                            help="probability of dropout")
        return parser, set()

    def __init__(self, opt, logger):
        """Build the networks, optimizers and schedulers.

        Args:
            opt: parsed options. Reads net_N, x_dim (one entry per modality),
                z_dim, n_classes, p_drop, modality (sequence of batch keys),
                lr and log_time.
            logger: forwarded to ``NetInterface.__init__``.
        """
        super().__init__(opt, logger)
        self.net_N = opt.net_N
        self.x_dim = opt.x_dim
        self.z_dim = opt.z_dim
        self.n_classes = opt.n_classes
        self.p_drop = opt.p_drop
        self.modality = opt.modality
        self.n_modality = len(self.modality)
        # one input dimension is required for every modality
        assert len(self.x_dim) == self.n_modality  # check the length of x_dim

        # init networks. Q/P stay interleaved per modality so construction
        # order (and any RNG use inside the constructors) is unchanged.
        self.net_q = [None] * self.n_modality
        self.net_p = [None] * self.n_modality
        for i in range(self.n_modality):
            self.net_q[i] = Q_net(self.net_N, self.x_dim[i], self.z_dim,
                                  self.p_drop)
            self.net_p[i] = P_net(self.net_N, self.x_dim[i], self.z_dim,
                                  self.p_drop)
        # the classifier sees the concatenation of all modality latents
        self.net_c = C_net(self.net_N, self.z_dim * self.n_modality,
                           self.n_classes, self.p_drop)
        self._nets = self.net_q + self.net_p + [self.net_c]

        # optimizers (self.optimizer / self.optimizer_params are set up by NetInterface)
        self.optimizer_q = [
            self.optimizer(net.parameters(), lr=opt.lr, **self.optimizer_params)
            for net in self.net_q
        ]
        self.optimizer_p = [
            self.optimizer(net.parameters(), lr=opt.lr, **self.optimizer_params)
            for net in self.net_p
        ]
        self.optimizer_c = self.optimizer(self.net_c.parameters(),
                                          lr=opt.lr,
                                          **self.optimizer_params)
        self._optimizers = self.optimizer_q + self.optimizer_p + [
            self.optimizer_c
        ]

        # schedulers (self.scheduler / self.scheduler_params come from NetInterface)
        self.scheduler_q = [
            self.scheduler(optim, **self.scheduler_params)
            for optim in self.optimizer_q
        ]
        self.scheduler_p = [
            self.scheduler(optim, **self.scheduler_params)
            for optim in self.optimizer_p
        ]
        self.scheduler_c = self.scheduler(self.optimizer_c,
                                          **self.scheduler_params)
        self._schedulers = self.scheduler_q + self.scheduler_p + [
            self.scheduler_c
        ]

        # general
        self.opt = opt
        # log the autoencoder loss and the classification loss
        self._metrics = ['loss_mse', 'loss_clf']
        if opt.log_time:
            self._metrics += ['t_mse', 't_clf']

        # init variables
        self.init_vars(add_path=True)

        # init weights
        for i in range(self.n_modality):
            self.init_weight(self.net_q[i])
            self.init_weight(self.net_p[i])
        self.init_weight(self.net_c)

    def __str__(self):
        s = "MNIST Simulated Autoencoder Concat"
        return s

    def _train_on_batch(self, epoch, batch_idx, batch):
        """Run one reconstruction step followed by one classification step.

        Args:
            epoch, batch_idx: not used by this implementation.
            batch: mapping with one tensor per name in ``self.modality`` plus
                ``'labels'``. Each modality tensor is flattened to (batch, -1).

        Returns:
            dict with ``'size'``, ``'loss_mse'`` (sum of the per-modality MSE
            values), ``'loss_clf'``, and ``'t_mse'``/``'t_clf'`` when
            ``opt.log_time`` is set.
        """
        net_q, net_p, net_c = self.net_q, self.net_p, self.net_c
        opt_q, opt_p, opt_c = self.optimizer_q, self.optimizer_p, self.optimizer_c
        for i in range(self.n_modality):
            net_q[i].train()
            net_p[i].train()
        net_c.train()

        X_list = [None] * self.n_modality
        X_recon_list = [None] * self.n_modality
        z_list = [None] * self.n_modality
        loss_mse_list = [None] * self.n_modality

        # read in training data, flattening each sample to a vector
        for i in range(self.n_modality):
            X_list[i] = batch[self.modality[i]]
            X_list[i] = X_list[i].view(X_list[i].shape[0], -1)
        X_list = [tmp.cuda() for tmp in X_list]
        y = batch['labels'].cuda()
        batch_size = X_list[0].shape[0]
        batch_log = {'size': batch_size}

        ###################################################
        # Stage 1: train Q and P with reconstruction loss #
        ###################################################
        for i in range(self.n_modality):
            net_q[i].zero_grad()
            net_p[i].zero_grad()
            for p in net_q[i].parameters():
                p.requires_grad = True
            for p in net_p[i].parameters():
                p.requires_grad = True
        # the classifier is frozen during reconstruction
        for p in net_c.parameters():
            p.requires_grad = False

        t0 = time()
        for i in range(self.n_modality):
            z_list[i] = net_q[i](X_list[i])
            X_recon_list[i] = net_p[i](z_list[i])
            loss_mse = nn.functional.mse_loss(X_recon_list[i],
                                              X_list[i])  # Mean square error
            loss_mse_list[i] = loss_mse.item()
            loss_mse.backward()
            opt_q[i].step()
            opt_p[i].step()
        t_mse = time() - t0
        batch_log['loss_mse'] = sum(loss_mse_list)

        ####################################################
        # Stage 2: train Q and C with classification loss #
        ####################################################
        for i in range(self.n_modality):
            # clears the gradients accumulated in stage 1
            net_q[i].zero_grad()
            for p in net_q[i].parameters():
                p.requires_grad = True
            # decoders are frozen during classification
            for p in net_p[i].parameters():
                p.requires_grad = False
        net_c.zero_grad()
        for p in net_c.parameters():
            p.requires_grad = True

        t0 = time()
        for i in range(self.n_modality):
            z_list[i] = net_q[i](X_list[i])
        # concatenate latents along the feature dimension
        z_combined = torch.cat(z_list, dim=1)
        pred = net_c(z_combined)
        criterion = nn.CrossEntropyLoss()
        loss_clf = criterion(pred, y)
        loss_clf.backward()
        for i in range(self.n_modality):
            opt_q[i].step()
        opt_c.step()
        t_clf = time() - t0
        batch_log['loss_clf'] = loss_clf.item()
        if self.opt.log_time:
            batch_log['t_mse'] = t_mse
            batch_log['t_clf'] = t_clf
        return batch_log

    def _vali_on_batch(self, epoch, batch_idx, batch):
        """Evaluate the classification loss on one batch without gradients.

        Returns:
            dict with ``'size'``, ``'loss_clf'`` and ``'loss'`` (the same value).
        """
        for i in range(self.n_modality):
            self.net_q[i].eval()
        self.net_c.eval()

        X_list = [None] * self.n_modality
        z_list = [None] * self.n_modality
        # read in validation data, flattening each sample to a vector
        for i in range(self.n_modality):
            X_list[i] = batch[self.modality[i]]
            X_list[i] = X_list[i].view(X_list[i].shape[0], -1)
        X_list = [tmp.cuda() for tmp in X_list]
        y = batch['labels'].cuda()
        batch_size = X_list[0].shape[0]
        batch_log = {'size': batch_size}

        with torch.no_grad():
            for i in range(self.n_modality):
                z_list[i] = self.net_q[i](X_list[i])
            z_combined = torch.cat(z_list, dim=1)
            pred = self.net_c(z_combined)
            criterion = nn.CrossEntropyLoss()
            loss_clf = criterion(pred, y)
        batch_log['loss_clf'] = loss_clf.item()
        batch_log['loss'] = loss_clf.item()
        return batch_log
class Model(NetInterface):
    """Per-modality encoders whose concatenated latents feed a survival head.

    Each modality ``i`` has an encoder ``net_q[i]`` (Q_net). Their outputs are
    concatenated along dim 1 and passed to ``net_c`` (C_net). Training
    minimises ``neg_par_log_likelihood`` on the batch's ``'days'``/``'event'``
    targets, and the concordance index (``CIndex``) is also reported.
    """

    @classmethod
    def add_arguments(cls, parser):
        """Register this model's command-line options on ``parser``.

        Returns:
            ``(parser, set())``.
        """
        parser.add_argument('--net_N',
                            type=int,
                            default=128,
                            help="Number of neurons in hidden layers")
        # With nargs='+' a value given on the command line is a list. The
        # default must also be a list, because __init__ calls len() on it and
        # indexes it per modality (a scalar default made len() raise TypeError).
        parser.add_argument('--x_dim',
                            type=int,
                            nargs='+',
                            default=[1000],
                            help="dimension of input features")
        parser.add_argument('--z_dim',
                            type=int,
                            default=100,
                            help="dimension of hidden variables")
        parser.add_argument('--n_classes',
                            type=int,
                            default=1,
                            help="number of nodes for classification")
        parser.add_argument('--p_drop',
                            type=float,
                            default=0.2,
                            help="probability of dropout")
        # NOTE(review): scalar default with nargs='+' means this is a float when
        # omitted and a list when given. It is only stored in this class.
        parser.add_argument('--clf_weights',
                            type=float,
                            nargs='+',
                            default=0.7,
                            help="classification weight for each class")
        return parser, set()

    def __init__(self, opt, logger):
        """Build the per-modality encoders, the head, optimizers and schedulers.

        Args:
            opt: parsed options. Reads net_N, x_dim (one entry per modality),
                z_dim, n_classes, p_drop, modality (sequence of batch keys),
                clf_weights, lr and log_time.
            logger: forwarded to ``NetInterface.__init__``.
        """
        super().__init__(opt, logger)
        self.net_N = opt.net_N
        self.x_dim = opt.x_dim
        self.z_dim = opt.z_dim
        self.n_classes = opt.n_classes
        self.p_drop = opt.p_drop
        self.modality = opt.modality
        self.n_modality = len(self.modality)
        self.clf_weights = opt.clf_weights
        # one input dimension is required for every modality
        assert len(self.x_dim) == self.n_modality  # check the length of x_dim

        # init networks
        self.net_q = [None] * self.n_modality
        for i in range(self.n_modality):
            self.net_q[i] = Q_net(self.net_N, self.x_dim[i], self.z_dim,
                                  self.p_drop)
        # the head sees the concatenation of all modality latents
        self.net_c = C_net(self.net_N, self.z_dim * self.n_modality,
                           self.n_classes, self.p_drop)
        self._nets = self.net_q + [self.net_c]

        # optimizers (self.optimizer / self.optimizer_params are set up by NetInterface)
        self.optimizer_q = [
            self.optimizer(net.parameters(), lr=opt.lr, **self.optimizer_params)
            for net in self.net_q
        ]
        self.optimizer_c = self.optimizer(self.net_c.parameters(),
                                          lr=opt.lr,
                                          **self.optimizer_params)
        self._optimizers = self.optimizer_q + [self.optimizer_c]

        # schedulers (self.scheduler / self.scheduler_params come from NetInterface)
        self.scheduler_q = [
            self.scheduler(optim, **self.scheduler_params)
            for optim in self.optimizer_q
        ]
        self.scheduler_c = self.scheduler(self.optimizer_c,
                                          **self.scheduler_params)
        self._schedulers = self.scheduler_q + [self.scheduler_c]

        # general
        self.opt = opt
        # log the survival loss and the concordance index
        self._metrics = ['loss_survival', 'c_index']
        if opt.log_time:
            self._metrics += ['t_survival']

        # init variables
        self.init_vars(add_path=True)

        # init weights
        for i in range(self.n_modality):
            self.init_weight(self.net_q[i])
        self.init_weight(self.net_c)

    def __str__(self):
        s = "Autoencoder Concat"
        return s

    def _train_on_batch(self, epoch, batch_idx, batch):
        """Run one optimisation step of Q and C on the survival loss.

        Args:
            epoch, batch_idx: not used by this implementation.
            batch: mapping with one tensor per name in ``self.modality`` plus
                ``'days'`` (survival time) and ``'event'`` (event indicator).

        Returns:
            dict with ``'size'``, ``'loss_survival'``, ``'c_index'`` and
            ``'t_survival'`` when ``opt.log_time`` is set. If the batch has no
            events, returns only ``'size'`` and a NaN ``'loss_survival'`` tensor.
        """
        net_q, net_c = self.net_q, self.net_c
        opt_q, opt_c = self.optimizer_q, self.optimizer_c
        for i in range(self.n_modality):
            net_q[i].train()
        net_c.train()

        X_list = [None] * self.n_modality
        z_list = [None] * self.n_modality
        # read in training data
        for i in range(self.n_modality):
            X_list[i] = batch[self.modality[i]]
        X_list = [tmp.cuda() for tmp in X_list]
        survival_time = batch['days'].cuda()
        survival_event = batch['event'].cuda()
        batch_size = X_list[0].shape[0]
        batch_log = {'size': batch_size}

        ##############################################
        # Train Q and C with the survival loss       #
        ##############################################
        # skip the batch if all instances are negative (no events); presumably
        # the partial log-likelihood is not meaningful then
        if not survival_event.sum(0):
            batch_log['loss_survival'] = torch.Tensor([float('nan')])
            return batch_log

        for i in range(self.n_modality):
            net_q[i].zero_grad()
            for p in net_q[i].parameters():
                p.requires_grad = True
        net_c.zero_grad()
        for p in net_c.parameters():
            p.requires_grad = True

        t0 = time()
        for i in range(self.n_modality):
            z_list[i] = net_q[i](X_list[i])
        # concatenate latents along the feature dimension
        z_combined = torch.cat(z_list, dim=1)
        pred = net_c(z_combined)

        loss_survival = neg_par_log_likelihood(pred, survival_time,
                                               survival_event)
        loss_survival.backward()
        for i in range(self.n_modality):
            opt_q[i].step()
        opt_c.step()
        t_survival = time() - t0
        c_index = CIndex(pred, survival_event, survival_time)
        batch_log['loss_survival'] = loss_survival.item()
        batch_log['c_index'] = c_index
        if self.opt.log_time:
            batch_log['t_survival'] = t_survival
        return batch_log

    def _vali_on_batch(self, epoch, batch_idx, batch):
        """Evaluate the survival loss and C-index on one batch without gradients.

        Returns:
            dict with ``'size'``, ``'loss'``, ``'loss_survival'`` (same value)
            and ``'c_index'``. If the batch has no events, returns only
            ``'size'`` and a NaN ``'loss_survival'`` tensor.
        """
        for i in range(self.n_modality):
            self.net_q[i].eval()
        self.net_c.eval()

        X_list = [None] * self.n_modality
        z_list = [None] * self.n_modality
        # read in validation data
        for i in range(self.n_modality):
            X_list[i] = batch[self.modality[i]]
        X_list = [tmp.cuda() for tmp in X_list]
        survival_time = batch['days'].cuda()
        survival_event = batch['event'].cuda()
        batch_size = X_list[0].shape[0]
        batch_log = {'size': batch_size}

        # skip the batch if all instances are negative (no events)
        if not survival_event.sum(0):
            batch_log['loss_survival'] = torch.Tensor([float('nan')])
            return batch_log

        with torch.no_grad():
            for i in range(self.n_modality):
                z_list[i] = self.net_q[i](X_list[i])
            z_combined = torch.cat(z_list, dim=1)
            pred = self.net_c(z_combined)
            loss_survival = neg_par_log_likelihood(pred, survival_time,
                                                   survival_event)
            c_index = CIndex(pred, survival_event, survival_time)

        batch_log['loss'] = loss_survival.item()
        batch_log['loss_survival'] = loss_survival.item()
        batch_log['c_index'] = c_index
        return batch_log
class Model(NetInterface):
    """Single-modality encoder (Q_net) followed by a classifier (C_net).

    Both networks are trained jointly on a cross-entropy loss. Inputs are
    flattened per sample before encoding.
    """

    @classmethod
    def add_arguments(cls, parser):
        """Add this model's CLI flags to ``parser`` and return ``(parser, set())``."""
        option_table = (
            ('--net_N',
             dict(type=int, default=128,
                  help="Number of neurons in hidden layers")),
            ('--x_dim',
             dict(type=int, default=1000,
                  help="dimension of input features")),
            ('--z_dim',
             dict(type=int, default=100,
                  help="dimension of hidden variables")),
            ('--n_classes',
             dict(type=int, default=1,
                  help="number of nodes for classification")),
            ('--p_drop',
             dict(type=float, default=0.2,
                  help="probability of dropout")),
            ('--clf_weights',
             dict(type=float, nargs='+', default=0.7,
                  help="classification weight for each class")),
        )
        for flag, spec in option_table:
            parser.add_argument(flag, **spec)
        return parser, set()

    def __init__(self, opt, logger):
        """Create the encoder/classifier pair with their optimizers and schedulers.

        Args:
            opt: parsed options. Reads net_N, x_dim, z_dim, n_classes, p_drop,
                modality (a list; only its first entry is used), clf_weights,
                lr and log_time.
            logger: forwarded to ``NetInterface.__init__``.
        """
        super().__init__(opt, logger)
        for attr in ('net_N', 'x_dim', 'z_dim', 'n_classes', 'p_drop'):
            setattr(self, attr, getattr(opt, attr))
        # opt.modality is a list; this single-input model uses its first entry
        self.modality = opt.modality[0]
        self.clf_weights = opt.clf_weights

        # networks
        self.net_q = Q_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_c = C_net(self.net_N, self.z_dim, self.n_classes, self.p_drop)
        self._nets = [self.net_q, self.net_c]

        # optimizers: self.optimizer / self.optimizer_params are prepared by NetInterface
        def build_optimizer(module):
            return self.optimizer(module.parameters(), lr=opt.lr,
                                  **self.optimizer_params)

        self.optimizer_q = build_optimizer(self.net_q)
        self.optimizer_c = build_optimizer(self.net_c)
        self._optimizers = [self.optimizer_q, self.optimizer_c]

        # schedulers: self.scheduler / self.scheduler_params are prepared by NetInterface
        self.scheduler_q, self.scheduler_c = (
            self.scheduler(optim, **self.scheduler_params)
            for optim in (self.optimizer_q, self.optimizer_c))
        self._schedulers = [self.scheduler_q, self.scheduler_c]

        self.opt = opt
        # only the classification loss (plus its timing, on request) is logged
        self._metrics = ['loss_clf'] + (['t_clf'] if opt.log_time else [])

        self.init_vars(add_path=True)
        for module in (self.net_q, self.net_c):
            self.init_weight(module)

    def __str__(self):
        return "Autoencoder"

    def _train_on_batch(self, epoch, batch_idx, batch):
        """Take one joint optimisation step of Q and C on cross-entropy.

        Returns:
            dict with ``'size'``, ``'loss_clf'`` and, when ``opt.log_time`` is
            set, ``'t_clf'``.
        """
        encoder, classifier = self.net_q, self.net_c
        modules = (encoder, classifier)
        for module in modules:
            module.train()

        # flatten each sample to one vector, e.g. (batch, 1, 28, 28) -> (batch, 784)
        features = batch[self.modality].cuda()
        features = features.view(features.shape[0], -1)
        targets = batch['labels'].cuda()
        log = {'size': features.shape[0]}

        for module in modules:
            module.zero_grad()
        for module in modules:
            for weight in module.parameters():
                weight.requires_grad = True

        started = time()
        logits = classifier(encoder(features))
        # CrossEntropyLoss combines LogSoftmax and NLLLoss
        loss = nn.CrossEntropyLoss()(logits, targets)
        loss.backward()
        self.optimizer_q.step()
        self.optimizer_c.step()
        elapsed = time() - started

        log['loss_clf'] = loss.item()
        if self.opt.log_time:
            log['t_clf'] = elapsed
        return log

    def _vali_on_batch(self, epoch, batch_idx, batch):
        """Compute the cross-entropy loss on one batch with gradients disabled.

        Returns:
            dict with ``'size'``, ``'loss'`` and ``'loss_clf'`` (same value).
        """
        encoder, classifier = self.net_q, self.net_c
        encoder.eval()
        classifier.eval()

        features = batch[self.modality].cuda()
        features = features.view(features.shape[0], -1)  # flatten each sample
        targets = batch['labels'].cuda()
        log = {'size': features.shape[0]}

        with torch.no_grad():
            logits = classifier(encoder(features))
            # CrossEntropyLoss combines LogSoftmax and NLLLoss
            loss_value = nn.CrossEntropyLoss()(logits, targets).item()

        log['loss'] = loss_value
        log['loss_clf'] = loss_value
        return log
class Model(NetInterface):
    """Concatenate all modalities, then autoencode (Q/P) and classify (C).

    Every modality tensor is flattened per sample and concatenated along dim 1
    into a single input ``X``. Each training batch first updates Q and P on the
    MSE reconstruction of ``X``, then updates Q and C on cross-entropy.
    """

    @classmethod
    def add_arguments(cls, parser):
        """Add this model's CLI flags to ``parser`` and return ``(parser, set())``."""
        option_table = (
            ('--net_N',
             dict(type=int, default=128,
                  help="Number of neurons in hidden layers")),
            ('--x_dim',
             dict(type=int, default=1000,
                  help="dimension of input features")),
            ('--z_dim',
             dict(type=int, default=100,
                  help="dimension of hidden variables")),
            ('--n_classes',
             dict(type=int, default=1,
                  help="number of nodes for classification")),
            ('--p_drop',
             dict(type=float, default=0.2,
                  help="probability of dropout")),
        )
        for flag, spec in option_table:
            parser.add_argument(flag, **spec)
        return parser, set()

    def __init__(self, opt, logger):
        """Create the Q/P/C networks with their optimizers and schedulers.

        Args:
            opt: parsed options. Reads net_N, x_dim (presumably the total width of
                the concatenated input; confirm), z_dim, n_classes, p_drop,
                modality (sequence of batch keys), lr and log_time.
            logger: forwarded to ``NetInterface.__init__``.
        """
        super().__init__(opt, logger)
        for attr in ('net_N', 'x_dim', 'z_dim', 'n_classes', 'p_drop',
                     'modality'):
            setattr(self, attr, getattr(opt, attr))
        self.n_modality = len(self.modality)

        # networks (constructed in the order Q, P, C)
        self.net_q = Q_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_p = P_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_c = C_net(self.net_N, self.z_dim, self.n_classes, self.p_drop)
        self._nets = [self.net_q, self.net_p, self.net_c]

        # one optimizer and one scheduler per network; the factories and their
        # parameters are prepared by NetInterface
        self.optimizer_q, self.optimizer_p, self.optimizer_c = (
            self.optimizer(module.parameters(), lr=opt.lr,
                           **self.optimizer_params) for module in self._nets)
        self._optimizers = [
            self.optimizer_q, self.optimizer_p, self.optimizer_c
        ]
        self.scheduler_q, self.scheduler_p, self.scheduler_c = (
            self.scheduler(optim, **self.scheduler_params)
            for optim in self._optimizers)
        self._schedulers = [
            self.scheduler_q, self.scheduler_p, self.scheduler_c
        ]

        self.opt = opt
        # reconstruction and classification losses, plus timings on request
        self._metrics = ['loss_mse', 'loss_clf'] + (['t_recon', 't_clf']
                                                     if opt.log_time else [])

        self.init_vars(add_path=True)
        for module in (self.net_q, self.net_p, self.net_c):
            self.init_weight(module)

    def __str__(self):
        return "Concat + Autoencoder"

    def _concat_modalities(self, batch):
        """Flatten each modality per sample, move it to CUDA, and concatenate along dim 1."""
        parts = []
        for key in self.modality:
            chunk = batch[key]
            parts.append(chunk.view(chunk.shape[0], -1).cuda())
        return torch.cat(parts, dim=1)

    def _train_on_batch(self, epoch, batch_idx, batch):
        """Run the reconstruction step (Q, P) and then the classification step (Q, C).

        Returns:
            dict with ``'size'``, ``'loss_mse'``, ``'loss_clf'`` and, when
            ``opt.log_time`` is set, ``'t_recon'``/``'t_clf'``.
        """
        encoder, decoder, classifier = self.net_q, self.net_p, self.net_c
        for module in (encoder, decoder, classifier):
            module.train()

        inputs = self._concat_modalities(batch)
        targets = batch['labels'].cuda()
        log = {'size': inputs.shape[0]}

        # --- stage 1: reconstruction; classifier frozen ---
        encoder.zero_grad()
        decoder.zero_grad()
        for module, trainable in ((encoder, True), (decoder, True),
                                  (classifier, False)):
            for weight in module.parameters():
                weight.requires_grad = trainable

        started = time()
        reconstruction = decoder(encoder(inputs))
        recon_loss = nn.functional.mse_loss(reconstruction, inputs)
        recon_loss.backward()
        self.optimizer_q.step()
        self.optimizer_p.step()
        recon_seconds = time() - started
        log['loss_mse'] = recon_loss.item()

        # --- stage 2: classification; decoder frozen ---
        encoder.zero_grad()
        classifier.zero_grad()
        for module, trainable in ((encoder, True), (classifier, True),
                                  (decoder, False)):
            for weight in module.parameters():
                weight.requires_grad = trainable

        started = time()
        logits = classifier(encoder(inputs))
        clf_loss = nn.CrossEntropyLoss()(logits, targets)
        clf_loss.backward()
        self.optimizer_q.step()
        self.optimizer_c.step()
        clf_seconds = time() - started
        log['loss_clf'] = clf_loss.item()

        if self.opt.log_time:
            log['t_recon'] = recon_seconds
            log['t_clf'] = clf_seconds
        return log

    def _vali_on_batch(self, epoch, batch_idx, batch):
        """Compute the cross-entropy loss on one batch with gradients disabled.

        Returns:
            dict with ``'size'``, ``'loss_clf'`` and ``'loss'`` (same value).
        """
        self.net_q.eval()
        self.net_c.eval()

        inputs = self._concat_modalities(batch)
        targets = batch['labels'].cuda()
        log = {'size': inputs.shape[0]}

        with torch.no_grad():
            logits = self.net_c(self.net_q(inputs))
            loss_value = nn.CrossEntropyLoss()(logits, targets).item()

        log['loss_clf'] = loss_value
        log['loss'] = loss_value
        return log
class Model(NetInterface):
    """Single-modality autoencoder (Q/P) with a survival head (C).

    Each training batch first updates Q and P on the MSE reconstruction of the
    input. It then updates Q and C on ``neg_par_log_likelihood`` using the
    batch's ``'days'``/``'event'`` targets, and also reports ``CIndex``.
    """

    @classmethod
    def add_arguments(cls, parser):
        """Register this model's command-line options on ``parser``.

        Returns:
            ``(parser, set())``.
        """
        parser.add_argument('--net_N',
                            type=int,
                            default=128,
                            help="Number of neurons in hidden layers")
        parser.add_argument('--x_dim',
                            type=int,
                            default=1000,
                            help="dimension of input features")
        parser.add_argument('--z_dim',
                            type=int,
                            default=100,
                            help="dimension of hidden variables")
        parser.add_argument('--n_classes',
                            type=int,
                            default=1,
                            help="number of nodes for classification")
        parser.add_argument('--p_drop',
                            type=float,
                            default=0.2,
                            help="probability of dropout")
        # NOTE(review): scalar default with nargs='+' means this is a float when
        # omitted and a list when given. It is only stored in this class.
        parser.add_argument('--clf_weights',
                            type=float,
                            nargs='+',
                            default=0.7,
                            help="classification weight for each class")
        return parser, set()

    def __init__(self, opt, logger):
        """Build the Q/P/C networks, their optimizers and schedulers.

        Args:
            opt: parsed options. Reads net_N, x_dim, z_dim, n_classes, p_drop,
                modality (a batch key, or a list/tuple whose first entry is used),
                clf_weights, lr and log_time.
            logger: forwarded to ``NetInterface.__init__``.
        """
        super().__init__(opt, logger)
        self.net_N = opt.net_N
        self.x_dim = opt.x_dim
        self.z_dim = opt.z_dim
        self.n_classes = opt.n_classes
        self.p_drop = opt.p_drop
        # Elsewhere in this file opt.modality is indexed as a list of modality
        # names. batch[...] needs a single hashable key (a list raises TypeError),
        # so this single-input model unwraps to the first entry. A plain key is
        # used unchanged.
        modality = opt.modality
        if isinstance(modality, (list, tuple)):
            modality = modality[0]
        self.modality = modality
        self.clf_weights = opt.clf_weights

        # init networks
        self.net_q = Q_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_p = P_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_c = C_net(self.net_N, self.z_dim, self.n_classes, self.p_drop)
        self._nets = [self.net_q, self.net_p, self.net_c]

        # optimizers
        # self.optimizer and self.optimizer_params have been initialized in netinterface
        self.optimizer_q = self.optimizer(self.net_q.parameters(),
                                          lr=opt.lr,
                                          **self.optimizer_params)
        self.optimizer_p = self.optimizer(self.net_p.parameters(),
                                          lr=opt.lr,
                                          **self.optimizer_params)
        self.optimizer_c = self.optimizer(self.net_c.parameters(),
                                          lr=opt.lr,
                                          **self.optimizer_params)
        self._optimizers = [
            self.optimizer_q, self.optimizer_p, self.optimizer_c
        ]

        # schedulers
        # self.scheduler and self.scheduler_params have been initialized in netinterface
        self.scheduler_q = self.scheduler(self.optimizer_q,
                                          **self.scheduler_params)
        self.scheduler_p = self.scheduler(self.optimizer_p,
                                          **self.scheduler_params)
        self.scheduler_c = self.scheduler(self.optimizer_c,
                                          **self.scheduler_params)
        self._schedulers = [
            self.scheduler_q, self.scheduler_p, self.scheduler_c
        ]

        # general
        self.opt = opt
        # log the reconstruction loss, the survival loss and the C-index
        self._metrics = ['loss_mse', 'loss_survival', 'c_index']
        if opt.log_time:
            self._metrics += ['t_recon', 't_survival']

        # init variables
        self.init_vars(add_path=True)

        # init weights
        self.init_weight(self.net_q)
        self.init_weight(self.net_p)
        self.init_weight(self.net_c)

    def __str__(self):
        s = "Autoencoder"
        return s

    def _train_on_batch(self, epoch, batch_idx, batch):
        """Run the reconstruction step, then (if the batch has events) the survival step.

        Args:
            epoch, batch_idx: not used by this implementation.
            batch: mapping with ``self.modality`` plus ``'days'`` (survival
                time) and ``'event'`` (event indicator).

        Returns:
            dict with ``'size'``, ``'loss_mse'``, ``'loss_survival'``,
            ``'c_index'`` and, when ``opt.log_time`` is set,
            ``'t_recon'``/``'t_survival'``. If the batch has no events, the
            survival step is skipped and ``'loss_survival'`` is a NaN tensor.
        """
        net_q, net_p, net_c = self.net_q, self.net_p, self.net_c
        opt_q, opt_p, opt_c = self.optimizer_q, self.optimizer_p, self.optimizer_c
        net_q.train()
        net_p.train()
        net_c.train()

        X = batch[self.modality].cuda()
        survival_time = batch['days'].cuda()
        survival_event = batch['event'].cuda()
        batch_size = X.shape[0]
        batch_log = {'size': batch_size}

        # Stage 1: train Q and P with reconstruction loss (C frozen)
        net_q.zero_grad()
        net_p.zero_grad()
        for p in net_q.parameters():
            p.requires_grad = True
        for p in net_p.parameters():
            p.requires_grad = True
        for p in net_c.parameters():
            p.requires_grad = False

        t0 = time()
        z = net_q(X)
        X_recon = net_p(z)
        loss_mse = nn.functional.mse_loss(X_recon, X)  # Mean square error
        loss_mse.backward()
        opt_q.step()
        opt_p.step()
        t_recon = time() - t0
        batch_log['loss_mse'] = loss_mse.item()

        # Stage 2: train Q and C with the survival loss (P frozen).
        # Skip it if all instances are negative (no events in the batch).
        if not survival_event.sum(0):
            # key matches the 'loss_survival' metric (was misspelled 'loss_survial')
            batch_log['loss_survival'] = torch.Tensor([float('nan')])
            return batch_log

        net_q.zero_grad()
        net_c.zero_grad()
        for p in net_q.parameters():
            p.requires_grad = True
        for p in net_c.parameters():
            p.requires_grad = True
        for p in net_p.parameters():
            p.requires_grad = False

        t0 = time()
        z = net_q(X)
        pred = net_c(z)
        loss_survival = neg_par_log_likelihood(pred, survival_time,
                                               survival_event)
        loss_survival.backward()
        opt_q.step()
        opt_c.step()
        t_survival = time() - t0
        c_index = CIndex(pred, survival_event, survival_time)
        batch_log['loss_survival'] = loss_survival.item()
        batch_log['c_index'] = c_index
        if self.opt.log_time:
            batch_log['t_recon'] = t_recon
            batch_log['t_survival'] = t_survival
        return batch_log

    def _vali_on_batch(self, epoch, batch_idx, batch):
        """Evaluate the survival loss and C-index on one batch without gradients.

        Returns:
            dict with ``'size'``, ``'loss'``, ``'loss_survival'`` (same value)
            and ``'c_index'``. If the batch has no events, returns only
            ``'size'`` and a NaN ``'loss_survival'`` tensor.
        """
        self.net_q.eval()
        self.net_c.eval()

        # load the data
        X = batch[self.modality].cuda()
        survival_time = batch['days'].cuda()
        survival_event = batch['event'].cuda()
        batch_size = X.shape[0]
        batch_log = {'size': batch_size}

        # skip the batch if all instances are negative (no events)
        if not survival_event.sum(0):
            batch_log['loss_survival'] = torch.Tensor([float('nan')])
            return batch_log

        with torch.no_grad():
            z = self.net_q(X)
            pred = self.net_c(z)
            loss_survival = neg_par_log_likelihood(pred, survival_time,
                                                   survival_event)
            c_index = CIndex(pred, survival_event, survival_time)

        batch_log['loss'] = loss_survival.item()
        batch_log['loss_survival'] = loss_survival.item()
        batch_log['c_index'] = c_index
        return batch_log