def __init__(self, opt, logger):
        """Create the Q, P and C networks with one optimizer/scheduler each.

        Args:
            opt: Parsed options. This method reads ``net_N``, ``x_dim``,
                ``z_dim``, ``n_classes``, ``p_drop``, ``modality``,
                ``clf_weights``, ``lr`` and ``log_time``.
            logger: Forwarded unchanged to the parent constructor.
        """
        super().__init__(opt, logger)
        self.net_N, self.x_dim, self.z_dim = opt.net_N, opt.x_dim, opt.z_dim
        self.n_classes, self.p_drop = opt.n_classes, opt.p_drop
        self.modality = opt.modality
        self.clf_weights = opt.clf_weights

        # Networks: Q and P share the same constructor arguments, while C is
        # built from z_dim and n_classes.
        qp_args = (self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_q = Q_net(*qp_args)
        self.net_p = P_net(*qp_args)
        self.net_c = C_net(self.net_N, self.z_dim, self.n_classes, self.p_drop)
        nets = [self.net_q, self.net_p, self.net_c]
        self._nets = nets

        # Optimizers, created in Q, P, C order.
        # self.optimizer / self.optimizer_params are not defined here; per the
        # original author's note they are initialized by the net interface.
        optimizers = [
            self.optimizer(net.parameters(),
                           lr=opt.lr,
                           **self.optimizer_params) for net in nets
        ]
        self.optimizer_q, self.optimizer_p, self.optimizer_c = optimizers
        self._optimizers = optimizers

        # Schedulers, one per optimizer (self.scheduler / self.scheduler_params
        # likewise come from the net interface).
        schedulers = [
            self.scheduler(optimizer, **self.scheduler_params)
            for optimizer in optimizers
        ]
        self.scheduler_q, self.scheduler_p, self.scheduler_c = schedulers
        self._schedulers = schedulers

        # General bookkeeping: names of the values logged for every batch.
        self.opt = opt
        metric_names = ['loss_mse', 'loss_survival', 'c_index']
        if opt.log_time:
            metric_names += ['t_recon', 't_survival']
        self._metrics = metric_names

        # Variable and weight initialization are delegated to helpers that
        # are defined outside this method.
        self.init_vars(add_path=True)
        for net in nets:
            self.init_weight(net)
    def __init__(self, opt, logger):
        """Build the Q, P and C networks and their training machinery.

        Args:
            opt: Parsed options. This method reads ``net_N``, ``x_dim``,
                ``z_dim``, ``n_classes``, ``p_drop``, ``modality``, ``lr``
                and ``log_time``.
            logger: Forwarded unchanged to the parent constructor.
        """
        super().__init__(opt, logger)
        for field in ('net_N', 'x_dim', 'z_dim', 'n_classes', 'p_drop',
                      'modality'):
            setattr(self, field, getattr(opt, field))
        self.n_modality = len(self.modality)

        # Networks
        self.net_q = Q_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_p = P_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_c = C_net(self.net_N, self.z_dim, self.n_classes, self.p_drop)
        self._nets = [self.net_q, self.net_p, self.net_c]

        # One optimizer per network, all with the same learning rate.
        def make_optimizer(net):
            return self.optimizer(net.parameters(),
                                  lr=opt.lr,
                                  **self.optimizer_params)

        self.optimizer_q = make_optimizer(self.net_q)
        self.optimizer_p = make_optimizer(self.net_p)
        self.optimizer_c = make_optimizer(self.net_c)
        self._optimizers = [
            self.optimizer_q, self.optimizer_p, self.optimizer_c
        ]

        # One scheduler per optimizer.
        def make_scheduler(optimizer):
            return self.scheduler(optimizer, **self.scheduler_params)

        self.scheduler_q = make_scheduler(self.optimizer_q)
        self.scheduler_p = make_scheduler(self.optimizer_p)
        self.scheduler_c = make_scheduler(self.optimizer_c)
        self._schedulers = [
            self.scheduler_q, self.scheduler_p, self.scheduler_c
        ]

        # General: log the reconstruction and classification losses, plus
        # the per-stage timings when opt.log_time is set.
        self.opt = opt
        self._metrics = ['loss_mse', 'loss_clf']
        if opt.log_time:
            self._metrics.extend(['t_recon', 't_clf'])

        # Variable and weight initialization are delegated to helpers that
        # are defined outside this method.
        self.init_vars(add_path=True)
        for net in (self.net_q, self.net_p, self.net_c):
            self.init_weight(net)
    def __init__(self, opt, logger):
        """Build the Q and C networks with their optimizers and schedulers.

        Args:
            opt: Parsed options. This method reads ``net_N``, ``x_dim``,
                ``z_dim``, ``n_classes``, ``p_drop``, ``modality``,
                ``clf_weights``, ``lr`` and ``log_time``.
            logger: Forwarded unchanged to the parent constructor.
        """
        super().__init__(opt, logger)
        self.net_N, self.x_dim, self.z_dim = opt.net_N, opt.x_dim, opt.z_dim
        self.n_classes, self.p_drop = opt.n_classes, opt.p_drop
        # The modality option is a list; only its first entry is used here.
        self.modality = opt.modality[0]
        self.clf_weights = opt.clf_weights

        # Networks
        self.net_q = Q_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_c = C_net(self.net_N, self.z_dim, self.n_classes, self.p_drop)
        nets = [self.net_q, self.net_c]
        self._nets = nets

        # Optimizers (self.optimizer / self.optimizer_params are not defined
        # here; per the original author's note they come from the net
        # interface).
        optimizers = [
            self.optimizer(net.parameters(),
                           lr=opt.lr,
                           **self.optimizer_params) for net in nets
        ]
        self.optimizer_q, self.optimizer_c = optimizers
        self._optimizers = optimizers

        # Schedulers (self.scheduler / self.scheduler_params likewise).
        schedulers = [
            self.scheduler(optimizer, **self.scheduler_params)
            for optimizer in optimizers
        ]
        self.scheduler_q, self.scheduler_c = schedulers
        self._schedulers = schedulers

        # General: only the classification loss is logged, plus its timing
        # when opt.log_time is set.
        self.opt = opt
        metric_names = ['loss_clf']
        if opt.log_time:
            metric_names += ['t_clf']
        self._metrics = metric_names

        # Variable and weight initialization are delegated to helpers that
        # are defined outside this method.
        self.init_vars(add_path=True)
        for net in nets:
            self.init_weight(net)
    def __init__(self, opt, logger):
        """Build one Q network per modality plus a single shared C network.

        Args:
            opt: Parsed options. This method reads ``net_N``, ``x_dim``
                (indexed once per modality), ``z_dim``, ``n_classes``,
                ``p_drop``, ``modality``, ``clf_weights``, ``lr`` and
                ``log_time``.
            logger: Forwarded unchanged to the parent constructor.
        """
        super().__init__(opt, logger)
        self.net_N, self.x_dim, self.z_dim = opt.net_N, opt.x_dim, opt.z_dim
        self.n_classes, self.p_drop = opt.n_classes, opt.p_drop
        self.modality = opt.modality
        self.n_modality = len(self.modality)
        self.clf_weights = opt.clf_weights
        # x_dim must provide exactly one entry per modality.
        assert len(self.x_dim) == self.n_modality

        # Networks: Q number idx is built with x_dim[idx]; C is built with
        # z_dim * n_modality as its second argument.
        self.net_q = [
            Q_net(self.net_N, self.x_dim[idx], self.z_dim, self.p_drop)
            for idx in range(self.n_modality)
        ]
        self.net_c = C_net(self.net_N, self.z_dim * self.n_modality,
                           self.n_classes, self.p_drop)
        self._nets = [*self.net_q, self.net_c]

        # Optimizers: one per Q network, then one for C.
        self.optimizer_q = [
            self.optimizer(net.parameters(),
                           lr=opt.lr,
                           **self.optimizer_params) for net in self.net_q
        ]
        self.optimizer_c = self.optimizer(self.net_c.parameters(),
                                          lr=opt.lr,
                                          **self.optimizer_params)
        self._optimizers = [*self.optimizer_q, self.optimizer_c]

        # Schedulers: one per optimizer, in the same order.
        self.scheduler_q = [
            self.scheduler(optimizer, **self.scheduler_params)
            for optimizer in self.optimizer_q
        ]
        self.scheduler_c = self.scheduler(self.optimizer_c,
                                          **self.scheduler_params)
        self._schedulers = [*self.scheduler_q, self.scheduler_c]

        # General: log the survival loss and the c-index, plus the timing
        # when opt.log_time is set.
        self.opt = opt
        metric_names = ['loss_survival', 'c_index']
        if opt.log_time:
            metric_names += ['t_survival']
        self._metrics = metric_names

        # Variable and weight initialization are delegated to helpers that
        # are defined outside this method.
        self.init_vars(add_path=True)
        for net in self.net_q:
            self.init_weight(net)
        self.init_weight(self.net_c)
# Example #5
0
 def __init__(self, args):
     """Create per-agent Q networks, a HyperNet, their targets and the optimizer.

     Args:
         args: Configuration object. This method reads ``n_agents`` and
             ``learning_rate`` and passes the whole object to ``Q_net``,
             ``HyperNet``, ``ReplayBuffer`` and ``Preference``.
     """
     self.args = args

     # Online networks: one Q_net per agent plus a single HyperNet.
     self.policy = [Q_net(args) for _ in range(args.n_agents)]
     self.hyperNet = HyperNet(args)

     # Target networks start out as deep copies of the online networks.
     self.policy_target = [copy.deepcopy(net) for net in self.policy]
     self.hyperNet_target = copy.deepcopy(self.hyperNet)

     self.replayBuffer = ReplayBuffer(args)
     self.preference_pool = Preference(args)

     # A single Adam optimizer covers every agent network and the HyperNet.
     param_groups = [net.parameters() for net in self.policy]
     param_groups.append(self.hyperNet.parameters())
     all_params = itertools.chain.from_iterable(param_groups)
     self.optim = torch.optim.Adam(all_params, lr=self.args.learning_rate)

     # StepLR scales the learning rate by gamma every step_size scheduler
     # steps.
     self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optim,
                                                         step_size=10,
                                                         gamma=0.95,
                                                         last_epoch=-1)
class Model(NetInterface):
    """Autoencoder plus classifier trained on concatenated modalities.

    Every training batch runs two stages: Q and P are first updated with an
    MSE reconstruction loss, then Q and C are updated with a cross-entropy
    loss on ``batch['labels']``.
    """

    @classmethod
    def add_arguments(cls, parser):
        """Register this model's command-line options on ``parser``.

        Returns:
            A tuple of the same parser and an empty set.
        """
        option_table = (
            ('--net_N', int, 128, "Number of neurons in hidden layers"),
            ('--x_dim', int, 1000, "dimension of input features"),
            ('--z_dim', int, 100, "dimension of hidden variables"),
            ('--n_classes', int, 1, "number of nodes for classification"),
            ('--p_drop', float, 0.2, "probability of dropout"),
        )
        for flag, value_type, default, help_text in option_table:
            parser.add_argument(flag,
                                type=value_type,
                                default=default,
                                help=help_text)
        return parser, set()

    def __init__(self, opt, logger):
        """Build the Q, P and C networks with their optimizers and schedulers.

        Args:
            opt: Parsed options. This method reads ``net_N``, ``x_dim``,
                ``z_dim``, ``n_classes``, ``p_drop``, ``modality``, ``lr``
                and ``log_time``.
            logger: Forwarded unchanged to the parent constructor.
        """
        super().__init__(opt, logger)
        self.net_N, self.x_dim, self.z_dim = opt.net_N, opt.x_dim, opt.z_dim
        self.n_classes, self.p_drop = opt.n_classes, opt.p_drop
        self.modality = opt.modality
        self.n_modality = len(self.modality)

        # Networks
        self.net_q = Q_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_p = P_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_c = C_net(self.net_N, self.z_dim, self.n_classes, self.p_drop)
        nets = [self.net_q, self.net_p, self.net_c]
        self._nets = nets

        # Optimizers, created in Q, P, C order.
        optimizers = [
            self.optimizer(net.parameters(),
                           lr=opt.lr,
                           **self.optimizer_params) for net in nets
        ]
        self.optimizer_q, self.optimizer_p, self.optimizer_c = optimizers
        self._optimizers = optimizers

        # Schedulers, one per optimizer.
        schedulers = [
            self.scheduler(optimizer, **self.scheduler_params)
            for optimizer in optimizers
        ]
        self.scheduler_q, self.scheduler_p, self.scheduler_c = schedulers
        self._schedulers = schedulers

        # General: log the reconstruction and classification losses, plus
        # the per-stage timings when opt.log_time is set.
        self.opt = opt
        metric_names = ['loss_mse', 'loss_clf']
        if opt.log_time:
            metric_names += ['t_recon', 't_clf']
        self._metrics = metric_names

        # Variable and weight initialization are delegated to helpers that
        # are defined outside this class.
        self.init_vars(add_path=True)
        for net in nets:
            self.init_weight(net)

    def __str__(self):
        return "Concat + Autoencoder"

    def _concat_modalities(self, batch):
        """Flatten each modality of ``batch`` and join them along dim 1 on the GPU."""
        flattened = []
        for idx in range(self.n_modality):
            feats = batch[self.modality[idx]]
            # Keep the leading (batch) dimension and collapse the rest.
            flattened.append(feats.view(feats.shape[0], -1))
        on_gpu = [feats.cuda() for feats in flattened]
        return torch.cat(on_gpu, dim=1).cuda()

    def _train_on_batch(self, epoch, batch_idx, batch):
        """Run the reconstruction step and the classification step on one batch.

        Returns:
            A dict with ``'size'``, ``'loss_mse'`` and ``'loss_clf'``, plus
            ``'t_recon'`` and ``'t_clf'`` when ``self.opt.log_time`` is set.
        """
        net_q, net_p, net_c = self.net_q, self.net_p, self.net_c
        for net in (net_q, net_p, net_c):
            net.train()

        # Concatenate the modalities before feeding the autoencoder.
        X = self._concat_modalities(batch)
        y = batch['labels'].cuda()
        batch_log = {'size': X.shape[0]}

        # Stage 1: update Q and P with the reconstruction loss; C is frozen.
        net_q.zero_grad()
        net_p.zero_grad()
        for net, trainable in ((net_q, True), (net_p, True), (net_c, False)):
            for param in net.parameters():
                param.requires_grad = trainable
        started = time()
        loss_mse = nn.functional.mse_loss(net_p(net_q(X)), X)
        loss_mse.backward()
        self.optimizer_q.step()
        self.optimizer_p.step()
        t_recon = time() - started
        batch_log['loss_mse'] = loss_mse.item()

        # Stage 2: update Q and C with the classification loss; P is frozen.
        net_q.zero_grad()
        net_c.zero_grad()
        for net, trainable in ((net_q, True), (net_c, True), (net_p, False)):
            for param in net.parameters():
                param.requires_grad = trainable
        started = time()
        pred = net_c(net_q(X))
        loss_clf = nn.CrossEntropyLoss()(pred, y)
        loss_clf.backward()
        self.optimizer_q.step()
        self.optimizer_c.step()
        t_clf = time() - started
        batch_log['loss_clf'] = loss_clf.item()

        if self.opt.log_time:
            batch_log.update(t_recon=t_recon, t_clf=t_clf)
        return batch_log

    def _vali_on_batch(self, epoch, batch_idx, batch):
        """Compute the classification loss on one validation batch.

        Returns:
            A dict with ``'size'``, ``'loss_clf'`` and ``'loss'``; both loss
            entries hold the same cross-entropy value.
        """
        self.net_q.eval()
        self.net_c.eval()

        # Concatenate the modalities before feeding the encoder.
        X = self._concat_modalities(batch)
        y = batch['labels'].cuda()

        with torch.no_grad():
            pred = self.net_c(self.net_q(X))
        loss_value = nn.CrossEntropyLoss()(pred, y).item()
        return {
            'size': X.shape[0],
            'loss_clf': loss_value,
            'loss': loss_value,
        }
class Model(NetInterface):
    """Q network followed by a C network, trained with cross-entropy.

    A single modality is read from each batch, flattened, passed through Q
    and then C, and scored against ``batch['labels']``.
    """

    @classmethod
    def add_arguments(cls, parser):
        """Register this model's command-line options on ``parser``.

        Returns:
            A tuple of the same parser and an empty set.
        """
        scalar_options = (
            ('--net_N', int, 128, "Number of neurons in hidden layers"),
            ('--x_dim', int, 1000, "dimension of input features"),
            ('--z_dim', int, 100, "dimension of hidden variables"),
            ('--n_classes', int, 1, "number of nodes for classification"),
            ('--p_drop', float, 0.2, "probability of dropout"),
        )
        for flag, value_type, default, help_text in scalar_options:
            parser.add_argument(flag,
                                type=value_type,
                                default=default,
                                help=help_text)
        # NOTE(review): with nargs='+' values given on the command line
        # arrive as a list, while the default is the bare float 0.7 --
        # confirm that consumers handle both.
        parser.add_argument('--clf_weights',
                            type=float,
                            nargs='+',
                            default=0.7,
                            help="classification weight for each class")
        return parser, set()

    def __init__(self, opt, logger):
        """Build the Q and C networks with their optimizers and schedulers.

        Args:
            opt: Parsed options. This method reads ``net_N``, ``x_dim``,
                ``z_dim``, ``n_classes``, ``p_drop``, ``modality``,
                ``clf_weights``, ``lr`` and ``log_time``.
            logger: Forwarded unchanged to the parent constructor.
        """
        super().__init__(opt, logger)
        self.net_N, self.x_dim, self.z_dim = opt.net_N, opt.x_dim, opt.z_dim
        self.n_classes, self.p_drop = opt.n_classes, opt.p_drop
        # The modality option is a list; only its first entry is used here.
        self.modality = opt.modality[0]
        self.clf_weights = opt.clf_weights

        # Networks
        self.net_q = Q_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_c = C_net(self.net_N, self.z_dim, self.n_classes, self.p_drop)
        nets = [self.net_q, self.net_c]
        self._nets = nets

        # Optimizers (self.optimizer / self.optimizer_params are not defined
        # here; per the original author's note they come from the net
        # interface).
        optimizers = [
            self.optimizer(net.parameters(),
                           lr=opt.lr,
                           **self.optimizer_params) for net in nets
        ]
        self.optimizer_q, self.optimizer_c = optimizers
        self._optimizers = optimizers

        # Schedulers (self.scheduler / self.scheduler_params likewise).
        schedulers = [
            self.scheduler(optimizer, **self.scheduler_params)
            for optimizer in optimizers
        ]
        self.scheduler_q, self.scheduler_c = schedulers
        self._schedulers = schedulers

        # General: only the classification loss is logged, plus its timing
        # when opt.log_time is set.
        self.opt = opt
        metric_names = ['loss_clf']
        if opt.log_time:
            metric_names += ['t_clf']
        self._metrics = metric_names

        # Variable and weight initialization are delegated to helpers that
        # are defined outside this class.
        self.init_vars(add_path=True)
        for net in nets:
            self.init_weight(net)

    def __str__(self):
        return "Autoencoder"

    def _train_on_batch(self, epoch, batch_idx, batch):
        """Run one cross-entropy update of Q and C on a training batch.

        Returns:
            A dict with ``'size'`` and ``'loss_clf'``, plus ``'t_clf'`` when
            ``self.opt.log_time`` is set.
        """
        net_q, net_c = self.net_q, self.net_c
        for net in (net_q, net_c):
            net.train()

        # Move the features to the GPU, then collapse every non-batch
        # dimension into one.
        raw = batch[self.modality].cuda()
        X = raw.view(raw.shape[0], -1)
        y = batch['labels'].cuda()
        batch_log = {'size': X.shape[0]}

        net_q.zero_grad()
        net_c.zero_grad()
        for net in (net_q, net_c):
            for param in net.parameters():
                param.requires_grad = True

        started = time()
        pred = net_c(net_q(X))
        # CrossEntropyLoss combines LogSoftmax with the negative
        # log-likelihood loss (NLLLoss).
        loss_clf = nn.CrossEntropyLoss()(pred, y)
        loss_clf.backward()
        self.optimizer_q.step()
        self.optimizer_c.step()
        t_clf = time() - started

        batch_log['loss_clf'] = loss_clf.item()
        if self.opt.log_time:
            batch_log['t_clf'] = t_clf
        return batch_log

    def _vali_on_batch(self, epoch, batch_idx, batch):
        """Compute the cross-entropy loss on one validation batch.

        Returns:
            A dict with ``'size'``, ``'loss'`` and ``'loss_clf'``; both loss
            entries hold the same value.
        """
        self.net_q.eval()
        self.net_c.eval()

        raw = batch[self.modality].cuda()
        X = raw.view(raw.shape[0], -1)  # collapse every non-batch dimension
        y = batch['labels'].cuda()

        with torch.no_grad():
            pred = self.net_c(self.net_q(X))
        # CrossEntropyLoss combines LogSoftmax with the negative
        # log-likelihood loss (NLLLoss).
        loss_value = nn.CrossEntropyLoss()(pred, y).item()
        return {
            'size': X.shape[0],
            'loss': loss_value,
            'loss_clf': loss_value,
        }
class Model(NetInterface):
    """Autoencoder (Q, P) with a survival head (C).

    Every training batch runs two stages: Q and P are first updated with an
    MSE reconstruction loss, then Q and C are updated with the loss returned
    by ``neg_par_log_likelihood``. Batches whose ``'event'`` entries sum to
    zero skip the survival stage and log NaN for ``'loss_survival'``.
    """

    @classmethod
    def add_arguments(cls, parser):
        """Register this model's command-line options on ``parser``.

        Returns:
            A tuple of the same parser and an empty set.
        """
        parser.add_argument('--net_N',
                            type=int,
                            default=128,
                            help="Number of neurons in hidden layers")
        parser.add_argument('--x_dim',
                            type=int,
                            default=1000,
                            help="dimension of input features")
        parser.add_argument('--z_dim',
                            type=int,
                            default=100,
                            help="dimension of hidden variables")
        parser.add_argument('--n_classes',
                            type=int,
                            default=1,
                            help="number of nodes for classification")
        parser.add_argument('--p_drop',
                            type=float,
                            default=0.2,
                            help="probability of dropout")
        # NOTE(review): with nargs='+' values given on the command line
        # arrive as a list, while the default is the bare float 0.7 --
        # confirm that consumers handle both.
        parser.add_argument('--clf_weights',
                            type=float,
                            nargs='+',
                            default=0.7,
                            help="classification weight for each class")
        return parser, set()

    def __init__(self, opt, logger):
        """Build the Q, P and C networks with their optimizers and schedulers.

        Args:
            opt: Parsed options. This method reads ``net_N``, ``x_dim``,
                ``z_dim``, ``n_classes``, ``p_drop``, ``modality``,
                ``clf_weights``, ``lr`` and ``log_time``.
            logger: Forwarded unchanged to the parent constructor.
        """
        super().__init__(opt, logger)
        self.net_N = opt.net_N
        self.x_dim = opt.x_dim
        self.z_dim = opt.z_dim
        self.n_classes = opt.n_classes
        self.p_drop = opt.p_drop
        # NOTE(review): used directly as a batch key in the batch methods,
        # so it must be hashable (e.g. a single string) -- confirm against
        # the option parser.
        self.modality = opt.modality
        self.clf_weights = opt.clf_weights

        # init networks
        self.net_q = Q_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_p = P_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_c = C_net(self.net_N, self.z_dim, self.n_classes, self.p_drop)
        self._nets = [self.net_q, self.net_p, self.net_c]

        # optimizers
        # self.optimizer / self.optimizer_params are not defined here; per
        # the original author's note they come from the net interface.
        self.optimizer_q = self.optimizer(self.net_q.parameters(),
                                          lr=opt.lr,
                                          **self.optimizer_params)
        self.optimizer_p = self.optimizer(self.net_p.parameters(),
                                          lr=opt.lr,
                                          **self.optimizer_params)
        self.optimizer_c = self.optimizer(self.net_c.parameters(),
                                          lr=opt.lr,
                                          **self.optimizer_params)
        self._optimizers = [
            self.optimizer_q, self.optimizer_p, self.optimizer_c
        ]

        # schedulers
        # self.scheduler / self.scheduler_params likewise come from the net
        # interface.
        self.scheduler_q = self.scheduler(self.optimizer_q,
                                          **self.scheduler_params)
        self.scheduler_p = self.scheduler(self.optimizer_p,
                                          **self.scheduler_params)
        self.scheduler_c = self.scheduler(self.optimizer_c,
                                          **self.scheduler_params)
        self._schedulers = [
            self.scheduler_q, self.scheduler_p, self.scheduler_c
        ]

        # general: log the reconstruction loss, the survival loss and the
        # c-index, plus the per-stage timings when opt.log_time is set
        self.opt = opt
        self._metrics = ['loss_mse', 'loss_survival', 'c_index']
        if opt.log_time:
            self._metrics += ['t_recon', 't_survival']

        # init variables
        self.init_vars(add_path=True)

        # init weights
        self.init_weight(self.net_q)
        self.init_weight(self.net_p)
        self.init_weight(self.net_c)

    def __str__(self):
        s = "Autoencoder"
        return s

    def _train_on_batch(self, epoch, batch_idx, batch):
        """Run the reconstruction step and the survival step on one batch.

        Returns:
            A dict with ``'size'`` and ``'loss_mse'``. When the survival
            stage runs it also holds ``'loss_survival'`` and ``'c_index'``
            (plus ``'t_recon'`` and ``'t_survival'`` when
            ``self.opt.log_time`` is set). When the batch contains no
            events, ``'loss_survival'`` is a one-element NaN tensor and the
            remaining keys are absent.
        """
        net_q, net_p, net_c = self.net_q, self.net_p, self.net_c
        opt_q, opt_p, opt_c = self.optimizer_q, self.optimizer_p, self.optimizer_c
        net_q.train()
        net_p.train()
        net_c.train()

        X = batch[self.modality].cuda()
        survival_time = batch['days'].cuda()
        survival_event = batch['event'].cuda()
        batch_size = X.shape[0]
        batch_log = {'size': batch_size}

        # Stage 1: train Q and P with reconstruction loss; C is frozen.
        net_q.zero_grad()
        net_p.zero_grad()
        for p in net_q.parameters():
            p.requires_grad = True
        for p in net_p.parameters():
            p.requires_grad = True
        for p in net_c.parameters():
            p.requires_grad = False
        t0 = time()
        z = net_q(X)
        X_recon = net_p(z)
        loss_mse = nn.functional.mse_loss(X_recon, X)  # Mean square error
        loss_mse.backward()
        opt_q.step()
        opt_p.step()
        t_recon = time() - t0
        batch_log['loss_mse'] = loss_mse.item()

        # Stage 2: train Q and C with the survival loss; P is frozen.
        # Skip the stage when the event indicators sum to zero. The key was
        # previously misspelled 'loss_survial', so the NaN placeholder never
        # reached the 'loss_survival' metric; it now matches the key used by
        # _vali_on_batch and self._metrics.
        if not survival_event.sum(0):
            batch_log['loss_survival'] = torch.Tensor([float('nan')])
            return batch_log
        net_q.zero_grad()
        net_c.zero_grad()
        for p in net_q.parameters():
            p.requires_grad = True
        for p in net_c.parameters():
            p.requires_grad = True
        for p in net_p.parameters():
            p.requires_grad = False
        t0 = time()
        z = net_q(X)
        pred = net_c(z)
        loss_survival = neg_par_log_likelihood(pred, survival_time,
                                               survival_event)
        loss_survival.backward()
        opt_q.step()
        opt_c.step()
        t_survival = time() - t0
        c_index = CIndex(pred, survival_event, survival_time)
        batch_log['loss_survival'] = loss_survival.item()
        batch_log['c_index'] = c_index
        if self.opt.log_time:
            batch_log['t_recon'] = t_recon
            batch_log['t_survival'] = t_survival
        return batch_log

    def _vali_on_batch(self, epoch, batch_idx, batch):
        """Compute the survival loss and c-index on one validation batch.

        Returns:
            A dict with ``'size'``, ``'loss'``, ``'loss_survival'`` and
            ``'c_index'``. When the batch contains no events, only
            ``'size'`` and a one-element NaN tensor under
            ``'loss_survival'`` are returned.
        """
        self.net_q.eval()
        self.net_c.eval()
        # load the data
        X = batch[self.modality].cuda()
        survival_time = batch['days'].cuda()
        survival_event = batch['event'].cuda()
        batch_size = X.shape[0]
        batch_log = {'size': batch_size}
        # Skip the batch when the event indicators sum to zero.
        if not survival_event.sum(0):
            batch_log['loss_survival'] = torch.Tensor([float('nan')])
            return batch_log
        with torch.no_grad():
            z = self.net_q(X)
            pred = self.net_c(z)
        loss_survival = neg_par_log_likelihood(pred, survival_time,
                                               survival_event)
        c_index = CIndex(pred, survival_event, survival_time)
        batch_log['loss'] = loss_survival.item()
        batch_log['loss_survival'] = loss_survival.item()
        batch_log['c_index'] = c_index
        return batch_log