def __init__(self, opt, logger):
        """Set up one encoder per modality and a shared head on their concatenated codes.

        Modality ``i`` gets a ``Q_net`` with input width ``opt.x_dim[i]`` and
        code width ``opt.z_dim``; the single ``C_net`` receives
        ``z_dim * n_modality`` features. Every network gets its own optimizer
        (learning rate ``opt.lr``) and scheduler. Logged metrics are
        ``loss_survival`` and ``c_index``, plus ``t_survival`` when
        ``opt.log_time`` is set.

        Args:
            opt: parsed options (net_N, x_dim, z_dim, n_classes, p_drop,
                modality, clf_weights, lr, log_time).
            logger: passed through to the parent constructor.

        Raises:
            AssertionError: if ``opt.x_dim`` does not have exactly one entry
                per modality.
        """
        super().__init__(opt, logger)
        self.net_N = opt.net_N
        self.x_dim = opt.x_dim
        self.z_dim = opt.z_dim
        self.n_classes = opt.n_classes
        self.p_drop = opt.p_drop
        self.modality = opt.modality
        self.n_modality = len(self.modality)
        self.clf_weights = opt.clf_weights
        # one input width is required for every modality
        assert len(self.x_dim) == self.n_modality

        # networks: per-modality encoders first, then the shared head
        self.net_q = [
            Q_net(self.net_N, self.x_dim[idx], self.z_dim, self.p_drop)
            for idx in range(self.n_modality)
        ]
        self.net_c = C_net(self.net_N, self.z_dim * self.n_modality,
                           self.n_classes, self.p_drop)
        self._nets = [*self.net_q, self.net_c]

        # optimizers and schedulers, index-aligned with the networks
        self.optimizer_q = [
            self.optimizer(encoder.parameters(), lr=opt.lr,
                           **self.optimizer_params)
            for encoder in self.net_q
        ]
        self.optimizer_c = self.optimizer(self.net_c.parameters(),
                                          lr=opt.lr,
                                          **self.optimizer_params)
        self._optimizers = [*self.optimizer_q, self.optimizer_c]

        self.scheduler_q = [
            self.scheduler(optim, **self.scheduler_params)
            for optim in self.optimizer_q
        ]
        self.scheduler_c = self.scheduler(self.optimizer_c,
                                          **self.scheduler_params)
        self._schedulers = [*self.scheduler_q, self.scheduler_c]

        # bookkeeping: survival loss and c-index are the logged metrics
        self.opt = opt
        self._metrics = ['loss_survival', 'c_index']
        if opt.log_time:
            self._metrics += ['t_survival']

        self.init_vars(add_path=True)

        for encoder in self.net_q:
            self.init_weight(encoder)
        self.init_weight(self.net_c)
    def __init__(self, opt, logger):
        """Set up a single autoencoder (``net_q``/``net_p``) plus a head ``net_c``.

        ``net_q`` and ``net_p`` are built with input width ``opt.x_dim`` and
        code width ``opt.z_dim``; ``net_c`` maps the ``z_dim`` code to
        ``opt.n_classes`` outputs. Each of the three networks gets its own
        optimizer (learning rate ``opt.lr``) and scheduler. Logged metrics are
        ``loss_mse``, ``loss_survival`` and ``c_index``, plus ``t_recon`` and
        ``t_survival`` when ``opt.log_time`` is set.

        Args:
            opt: parsed options (net_N, x_dim, z_dim, n_classes, p_drop,
                modality, clf_weights, lr, log_time).
            logger: passed through to the parent constructor.
        """
        super().__init__(opt, logger)
        self.net_N = opt.net_N
        self.x_dim = opt.x_dim
        self.z_dim = opt.z_dim
        self.n_classes = opt.n_classes
        self.p_drop = opt.p_drop
        self.modality = opt.modality
        self.clf_weights = opt.clf_weights

        # networks, constructed in the order Q, P, C
        net_q = Q_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        net_p = P_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        net_c = C_net(self.net_N, self.z_dim, self.n_classes, self.p_drop)
        self.net_q, self.net_p, self.net_c = net_q, net_p, net_c
        nets = [net_q, net_p, net_c]
        self._nets = nets

        # self.optimizer / self.optimizer_params and self.scheduler /
        # self.scheduler_params have been initialized in netinterface
        optimizers = [
            self.optimizer(net.parameters(), lr=opt.lr,
                           **self.optimizer_params) for net in nets
        ]
        self.optimizer_q, self.optimizer_p, self.optimizer_c = optimizers
        self._optimizers = optimizers

        schedulers = [
            self.scheduler(optim, **self.scheduler_params)
            for optim in optimizers
        ]
        self.scheduler_q, self.scheduler_p, self.scheduler_c = schedulers
        self._schedulers = schedulers

        # bookkeeping: reconstruction + survival metrics
        self.opt = opt
        metrics = ['loss_mse', 'loss_survival', 'c_index']
        if opt.log_time:
            metrics += ['t_recon', 't_survival']
        self._metrics = metrics

        self.init_vars(add_path=True)

        for net in (self.net_q, self.net_p, self.net_c):
            self.init_weight(net)
    def __init__(self, opt, logger):
        """Set up a single autoencoder (``net_q``/``net_p``) and a classifier ``net_c``.

        ``net_q`` and ``net_p`` use input width ``opt.x_dim`` and code width
        ``opt.z_dim``; ``net_c`` maps the code to ``opt.n_classes`` outputs.
        Each network gets its own optimizer (learning rate ``opt.lr``) and
        scheduler. Logged metrics are ``loss_mse`` and ``loss_clf``, plus
        ``t_recon`` and ``t_clf`` when ``opt.log_time`` is set.

        Args:
            opt: parsed options (net_N, x_dim, z_dim, n_classes, p_drop,
                modality, lr, log_time).
            logger: passed through to the parent constructor.
        """
        super().__init__(opt, logger)
        self.net_N = opt.net_N
        self.x_dim = opt.x_dim
        self.z_dim = opt.z_dim
        self.n_classes = opt.n_classes
        self.p_drop = opt.p_drop
        self.modality = opt.modality
        self.n_modality = len(self.modality)

        # networks, constructed in the order Q, P, C
        self.net_q = Q_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_p = P_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_c = C_net(self.net_N, self.z_dim, self.n_classes, self.p_drop)
        nets = [self.net_q, self.net_p, self.net_c]
        self._nets = nets

        # one optimizer per network, then one scheduler per optimizer
        optimizers = [
            self.optimizer(net.parameters(), lr=opt.lr,
                           **self.optimizer_params) for net in nets
        ]
        self.optimizer_q, self.optimizer_p, self.optimizer_c = optimizers
        self._optimizers = optimizers

        schedulers = [
            self.scheduler(optim, **self.scheduler_params)
            for optim in optimizers
        ]
        self.scheduler_q, self.scheduler_p, self.scheduler_c = schedulers
        self._schedulers = schedulers

        # bookkeeping: reconstruction + classification losses
        self.opt = opt
        timing = ['t_recon', 't_clf'] if opt.log_time else []
        self._metrics = ['loss_mse', 'loss_clf'] + timing

        self.init_vars(add_path=True)

        self.init_weight(self.net_q)
        self.init_weight(self.net_p)
        self.init_weight(self.net_c)
    def __init__(self, opt, logger):
        """Set up a single-modality classifier: encoder ``net_q`` followed by ``net_c``.

        ``net_q`` maps ``opt.x_dim`` inputs to a ``opt.z_dim`` code and
        ``net_c`` maps that code to ``opt.n_classes`` outputs. Each network gets
        its own optimizer (learning rate ``opt.lr``) and scheduler. The logged
        metric is ``loss_clf``, plus ``t_clf`` when ``opt.log_time`` is set.

        Args:
            opt: parsed options (net_N, x_dim, z_dim, n_classes, p_drop,
                modality, clf_weights, lr, log_time).
            logger: passed through to the parent constructor.
        """
        super().__init__(opt, logger)
        self.net_N = opt.net_N
        self.x_dim = opt.x_dim
        self.z_dim = opt.z_dim
        self.n_classes = opt.n_classes
        self.p_drop = opt.p_drop
        # opt.modality is a list; this model only uses its first entry
        self.modality = opt.modality[0]
        self.clf_weights = opt.clf_weights

        # networks: encoder then classifier
        self.net_q = Q_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_c = C_net(self.net_N, self.z_dim, self.n_classes, self.p_drop)
        self._nets = [self.net_q, self.net_c]

        # self.optimizer / self.optimizer_params and self.scheduler /
        # self.scheduler_params have been initialized in netinterface
        def build_optimizer(net):
            return self.optimizer(net.parameters(),
                                  lr=opt.lr,
                                  **self.optimizer_params)

        self.optimizer_q = build_optimizer(self.net_q)
        self.optimizer_c = build_optimizer(self.net_c)
        self._optimizers = [self.optimizer_q, self.optimizer_c]

        self.scheduler_q = self.scheduler(self.optimizer_q,
                                          **self.scheduler_params)
        self.scheduler_c = self.scheduler(self.optimizer_c,
                                          **self.scheduler_params)
        self._schedulers = [self.scheduler_q, self.scheduler_c]

        # bookkeeping: only the classification loss is logged
        self.opt = opt
        self._metrics = ['loss_clf', 't_clf'] if opt.log_time else ['loss_clf']

        self.init_vars(add_path=True)

        for net in (self.net_q, self.net_c):
            self.init_weight(net)
# Example #5 (scraped-listing header, originally "예제 #5")
# 0
class Model(NetInterface):
    """Per-modality autoencoders whose codes are concatenated for classification.

    Modality ``i`` has an encoder ``net_q[i]`` (``Q_net``) and a decoder
    ``net_p[i]`` (``P_net``) sized by ``x_dim[i]``. The ``z_dim``-wide codes of
    all modalities are concatenated and fed to one classifier ``net_c``
    (``C_net``). Each training batch runs two stages: a reconstruction (MSE)
    stage updating every Q and P, then a cross-entropy stage updating every Q
    and C.
    """

    @classmethod
    def add_arguments(cls, parser):
        """Register this model's options on ``parser``; return ``(parser, set())``."""
        parser.add_argument('--net_N',
                            type=int,
                            default=128,
                            help="Number of neurons in hidden layers")
        # argparse uses ``default`` verbatim (it is not wrapped in a list for
        # nargs='+'), so the default must itself be a list: __init__ calls
        # len(self.x_dim) and indexes it once per modality.
        parser.add_argument(
            '--x_dim',
            type=int,
            nargs='+',
            default=[1000],
            help=
            "dimension of input features, take one or more arguments for integration"
        )
        parser.add_argument('--z_dim',
                            type=int,
                            default=100,
                            help="dimension of hidden variables")
        # NOTE(review): training uses nn.CrossEntropyLoss, which is degenerate
        # with a single class; pass --n_classes explicitly for classification.
        parser.add_argument('--n_classes',
                            type=int,
                            default=1,
                            help="number of nodes for classification")
        parser.add_argument('--p_drop',
                            type=float,
                            default=0.2,
                            help="probability of dropout")
        return parser, set()

    def __init__(self, opt, logger):
        """Build per-modality Q/P networks, the shared classifier, and their optimizers.

        Args:
            opt: parsed options (net_N, x_dim with one entry per modality,
                z_dim, n_classes, p_drop, modality, lr, log_time).
            logger: passed through to the parent constructor.

        Raises:
            AssertionError: if ``opt.x_dim`` does not have one entry per
                modality.
        """
        super().__init__(opt, logger)
        self.net_N = opt.net_N
        self.x_dim = opt.x_dim
        self.z_dim = opt.z_dim
        self.n_classes = opt.n_classes
        self.p_drop = opt.p_drop
        self.modality = opt.modality
        self.n_modality = len(self.modality)
        assert len(self.x_dim) == self.n_modality  # check the length of x_dim

        # init networks: encoder/decoder pair per modality, one classifier on
        # the concatenation of all codes
        self.net_q = [None] * self.n_modality
        self.net_p = [None] * self.n_modality
        for i in range(self.n_modality):
            self.net_q[i] = Q_net(self.net_N, self.x_dim[i], self.z_dim,
                                  self.p_drop)
            self.net_p[i] = P_net(self.net_N, self.x_dim[i], self.z_dim,
                                  self.p_drop)

        self.net_c = C_net(self.net_N, self.z_dim * self.n_modality,
                           self.n_classes, self.p_drop)
        self._nets = self.net_q + self.net_p + [self.net_c]
        # optimizers (one per network)
        self.optimizer_q = [None] * self.n_modality
        self.optimizer_p = [None] * self.n_modality

        for i in range(self.n_modality):
            self.optimizer_q[i] = self.optimizer(self.net_q[i].parameters(),
                                                 lr=opt.lr,
                                                 **self.optimizer_params)
            self.optimizer_p[i] = self.optimizer(self.net_p[i].parameters(),
                                                 lr=opt.lr,
                                                 **self.optimizer_params)
        self.optimizer_c = self.optimizer(self.net_c.parameters(),
                                          lr=opt.lr,
                                          **self.optimizer_params)
        self._optimizers = self.optimizer_q + self.optimizer_p + [
            self.optimizer_c
        ]
        # schedulers (one per optimizer)
        self.scheduler_q = [None] * self.n_modality
        self.scheduler_p = [None] * self.n_modality
        for i in range(self.n_modality):
            self.scheduler_q[i] = self.scheduler(self.optimizer_q[i],
                                                 **self.scheduler_params)
            self.scheduler_p[i] = self.scheduler(self.optimizer_p[i],
                                                 **self.scheduler_params)
        self.scheduler_c = self.scheduler(self.optimizer_c,
                                          **self.scheduler_params)
        self._schedulers = self.scheduler_q + self.scheduler_p + [
            self.scheduler_c
        ]

        # general: log the autoencoder loss and the classification loss
        self.opt = opt
        self._metrics = ['loss_mse', 'loss_clf']
        if opt.log_time:
            self._metrics += ['t_mse', 't_clf']

        # init variables
        self.init_vars(add_path=True)

        # init weights
        for i in range(self.n_modality):
            self.init_weight(self.net_q[i])
            self.init_weight(self.net_p[i])
        self.init_weight(self.net_c)

    def __str__(self):
        s = "MNIST Simulated Autoencoder Concat"
        return s

    def _train_on_batch(self, epoch, batch_idx, batch):
        """Run one reconstruction stage and one classification stage on ``batch``.

        ``batch`` must contain one tensor per name in ``self.modality`` (each
        flattened to ``(batch, -1)``) and a ``'labels'`` tensor.

        Returns:
            dict: ``size``, ``loss_mse`` (sum of per-modality MSE values),
            ``loss_clf``, plus ``t_mse`` / ``t_clf`` wall-clock seconds when
            ``opt.log_time`` is set.
        """
        net_q, net_p, net_c = self.net_q, self.net_p, self.net_c
        opt_q, opt_p, opt_c = self.optimizer_q, self.optimizer_p, self.optimizer_c
        for i in range(self.n_modality):
            net_q[i].train()
            net_p[i].train()
        net_c.train()

        X_list = [None] * self.n_modality
        X_recon_list = [None] * self.n_modality
        z_list = [None] * self.n_modality
        loss_mse_list = [None] * self.n_modality
        # read in training data, flattening each modality to (batch, -1)
        for i in range(self.n_modality):
            X_list[i] = batch[self.modality[i]]
            X_list[i] = X_list[i].view(X_list[i].shape[0], -1)
        X_list = [tmp.cuda() for tmp in X_list]
        y = batch['labels'].cuda()
        batch_size = X_list[0].shape[0]
        batch_log = {'size': batch_size}

        ###################################################
        # Stage 1: train Q and P with reconstruction loss #
        ###################################################
        # C is frozen so the reconstruction step cannot touch it
        for i in range(self.n_modality):
            net_q[i].zero_grad()
            net_p[i].zero_grad()
            for p in net_q[i].parameters():
                p.requires_grad = True
            for p in net_p[i].parameters():
                p.requires_grad = True
        for p in net_c.parameters():
            p.requires_grad = False

        t0 = time()
        for i in range(self.n_modality):
            z_list[i] = net_q[i](X_list[i])
            X_recon_list[i] = net_p[i](z_list[i])
            loss_mse = nn.functional.mse_loss(X_recon_list[i],
                                              X_list[i])  # Mean square error
            loss_mse_list[i] = loss_mse.item()
            loss_mse.backward()
            opt_q[i].step()
            opt_p[i].step()
        t_mse = time() - t0
        batch_log['loss_mse'] = sum(loss_mse_list)

        ####################################################
        # Stage 2: train Q and C with classification loss  #
        ####################################################
        # P is frozen; encoders and classifier are updated jointly
        for i in range(self.n_modality):
            net_q[i].zero_grad()
            for p in net_q[i].parameters():
                p.requires_grad = True
            for p in net_p[i].parameters():
                p.requires_grad = False
        net_c.zero_grad()
        for p in net_c.parameters():
            p.requires_grad = True

        t0 = time()
        for i in range(self.n_modality):
            z_list[i] = net_q[i](X_list[i])
        # concatenate the per-modality codes along the feature axis
        z_combined = torch.cat(z_list, dim=1)
        pred = net_c(z_combined)

        criterion = nn.CrossEntropyLoss()
        loss_clf = criterion(pred, y)
        loss_clf.backward()
        for i in range(self.n_modality):
            opt_q[i].step()
        opt_c.step()
        t_clf = time() - t0
        batch_log['loss_clf'] = loss_clf.item()

        if self.opt.log_time:
            batch_log['t_mse'] = t_mse
            batch_log['t_clf'] = t_clf
        return batch_log

    def _vali_on_batch(self, epoch, batch_idx, batch):
        """Evaluate the classification loss on ``batch`` without gradient tracking.

        Returns:
            dict: ``size``, ``loss_clf`` and ``loss`` (the same cross-entropy
            value).
        """
        for i in range(self.n_modality):
            self.net_q[i].eval()
        self.net_c.eval()

        X_list = [None] * self.n_modality
        z_list = [None] * self.n_modality
        # read in validation data, flattening each modality to (batch, -1)
        for i in range(self.n_modality):
            X_list[i] = batch[self.modality[i]]
            X_list[i] = X_list[i].view(X_list[i].shape[0], -1)

        X_list = [tmp.cuda() for tmp in X_list]
        y = batch['labels'].cuda()
        batch_size = X_list[0].shape[0]
        batch_log = {'size': batch_size}
        with torch.no_grad():
            for i in range(self.n_modality):
                z_list[i] = self.net_q[i](X_list[i])
            z_combined = torch.cat(z_list, dim=1)
            pred = self.net_c(z_combined)

        criterion = nn.CrossEntropyLoss()
        loss_clf = criterion(pred, y)
        batch_log['loss_clf'] = loss_clf.item()
        batch_log['loss'] = loss_clf.item()
        return batch_log
class Model(NetInterface):
    """Per-modality encoders whose codes are concatenated for survival prediction.

    Modality ``i`` has an encoder ``net_q[i]`` (``Q_net``) sized by
    ``x_dim[i]``. The ``z_dim``-wide codes are concatenated and fed to one head
    ``net_c`` (``C_net``). Training minimises ``neg_par_log_likelihood``;
    ``CIndex`` is logged alongside it.
    """

    @classmethod
    def add_arguments(cls, parser):
        """Register this model's options on ``parser``; return ``(parser, set())``."""
        parser.add_argument('--net_N',
                            type=int,
                            default=128,
                            help="Number of neurons in hidden layers")
        # argparse uses ``default`` verbatim for nargs='+', so both list-valued
        # options need list defaults; __init__ calls len(self.x_dim).
        parser.add_argument('--x_dim',
                            type=int,
                            nargs='+',
                            default=[1000],
                            help="dimension of input features")
        parser.add_argument('--z_dim',
                            type=int,
                            default=100,
                            help="dimension of hidden variables")
        parser.add_argument('--n_classes',
                            type=int,
                            default=1,
                            help="number of nodes for classification")
        parser.add_argument('--p_drop',
                            type=float,
                            default=0.2,
                            help="probability of dropout")
        parser.add_argument('--clf_weights',
                            type=float,
                            nargs='+',
                            default=[0.7],
                            help="classification weight for each class")
        return parser, set()

    def __init__(self, opt, logger):
        """Build per-modality encoders, the shared head, and their optimizers.

        Args:
            opt: parsed options (net_N, x_dim with one entry per modality,
                z_dim, n_classes, p_drop, modality, clf_weights, lr, log_time).
            logger: passed through to the parent constructor.

        Raises:
            AssertionError: if ``opt.x_dim`` does not have one entry per
                modality.
        """
        super().__init__(opt, logger)
        self.net_N = opt.net_N
        self.x_dim = opt.x_dim
        self.z_dim = opt.z_dim
        self.n_classes = opt.n_classes
        self.p_drop = opt.p_drop
        self.modality = opt.modality
        self.n_modality = len(self.modality)
        self.clf_weights = opt.clf_weights  # stored; not used in this class
        assert len(self.x_dim) == self.n_modality  # check the length of x_dim

        # init networks: one encoder per modality, one head on the
        # concatenation of all codes
        self.net_q = [None] * self.n_modality
        for i in range(self.n_modality):
            self.net_q[i] = Q_net(self.net_N, self.x_dim[i], self.z_dim,
                                  self.p_drop)

        self.net_c = C_net(self.net_N, self.z_dim * self.n_modality,
                           self.n_classes, self.p_drop)
        self._nets = self.net_q + [self.net_c]

        # optimizers (one per network)
        self.optimizer_q = [None] * self.n_modality

        for i in range(self.n_modality):
            self.optimizer_q[i] = self.optimizer(self.net_q[i].parameters(),
                                                 lr=opt.lr,
                                                 **self.optimizer_params)
        self.optimizer_c = self.optimizer(self.net_c.parameters(),
                                          lr=opt.lr,
                                          **self.optimizer_params)
        self._optimizers = self.optimizer_q + [self.optimizer_c]

        # schedulers (one per optimizer)
        self.scheduler_q = [None] * self.n_modality
        for i in range(self.n_modality):
            self.scheduler_q[i] = self.scheduler(self.optimizer_q[i],
                                                 **self.scheduler_params)
        self.scheduler_c = self.scheduler(self.optimizer_c,
                                          **self.scheduler_params)
        self._schedulers = self.scheduler_q + [self.scheduler_c]

        # general: log the survival loss and the concordance index
        self.opt = opt
        self._metrics = ['loss_survival', 'c_index']
        if opt.log_time:
            self._metrics += ['t_survival']

        # init variables
        self.init_vars(add_path=True)

        # init weights
        for i in range(self.n_modality):
            self.init_weight(self.net_q[i])
        self.init_weight(self.net_c)

    def __str__(self):
        s = "Autoencoder Concat"
        return s

    def _train_on_batch(self, epoch, batch_idx, batch):
        """Run one survival-loss update of every encoder and the head on ``batch``.

        ``batch`` must contain one tensor per name in ``self.modality`` plus
        ``'days'`` (survival time) and ``'event'`` (event indicator) tensors.

        Returns:
            dict: ``size``, ``loss_survival`` and ``c_index``, plus
            ``t_survival`` when ``opt.log_time`` is set. Batches without any
            event are skipped (see below).
        """
        net_q, net_c = self.net_q, self.net_c
        opt_q, opt_c = self.optimizer_q, self.optimizer_c
        for i in range(self.n_modality):
            net_q[i].train()
        net_c.train()

        X_list = [None] * self.n_modality
        z_list = [None] * self.n_modality
        # read in training data
        for i in range(self.n_modality):
            X_list[i] = batch[self.modality[i]]
        X_list = [tmp.cuda() for tmp in X_list]
        survival_time = batch['days'].cuda()
        survival_event = batch['event'].cuda()
        batch_size = X_list[0].shape[0]
        batch_log = {'size': batch_size}

        ##############################################
        # Train Q and C with the survival loss       #
        ##############################################
        # Skip the batch when no instance has an observed event (the event
        # indicators sum to zero).
        # NOTE(review): this path logs a tensor NaN (not a float like the
        # normal path) and no c_index; confirm the logger handles both.
        if not survival_event.sum(0):
            batch_log['loss_survival'] = torch.Tensor([float('nan')])
            return batch_log
        for i in range(self.n_modality):
            net_q[i].zero_grad()
            for p in net_q[i].parameters():
                p.requires_grad = True
        net_c.zero_grad()
        for p in net_c.parameters():
            p.requires_grad = True

        t0 = time()
        for i in range(self.n_modality):
            z_list[i] = net_q[i](X_list[i])
        # concatenate the per-modality codes along the feature axis
        z_combined = torch.cat(z_list, dim=1)
        pred = net_c(z_combined)

        loss_survival = neg_par_log_likelihood(pred, survival_time,
                                               survival_event)
        loss_survival.backward()
        for i in range(self.n_modality):
            opt_q[i].step()
        opt_c.step()
        t_survival = time() - t0
        c_index = CIndex(pred, survival_event, survival_time)
        batch_log['loss_survival'] = loss_survival.item()
        batch_log['c_index'] = c_index

        if self.opt.log_time:
            batch_log['t_survival'] = t_survival
        return batch_log

    def _vali_on_batch(self, epoch, batch_idx, batch):
        """Evaluate survival loss and c-index on ``batch`` without gradient tracking.

        Returns:
            dict: ``size``, ``loss`` and ``loss_survival`` (same value), and
            ``c_index``; batches without any event are skipped as in training.
        """
        for i in range(self.n_modality):
            self.net_q[i].eval()
        self.net_c.eval()

        X_list = [None] * self.n_modality
        z_list = [None] * self.n_modality
        # read in validation data
        for i in range(self.n_modality):
            X_list[i] = batch[self.modality[i]]
        X_list = [tmp.cuda() for tmp in X_list]
        survival_time = batch['days'].cuda()
        survival_event = batch['event'].cuda()

        batch_size = X_list[0].shape[0]
        batch_log = {'size': batch_size}
        # skip the batch when no instance has an observed event
        if not survival_event.sum(0):
            batch_log['loss_survival'] = torch.Tensor([float('nan')])
            return batch_log

        with torch.no_grad():
            for i in range(self.n_modality):
                z_list[i] = self.net_q[i](X_list[i])
            z_combined = torch.cat(z_list, dim=1)
            pred = self.net_c(z_combined)

        loss_survival = neg_par_log_likelihood(pred, survival_time,
                                               survival_event)
        c_index = CIndex(pred, survival_event, survival_time)
        batch_log['loss'] = loss_survival.item()
        batch_log['loss_survival'] = loss_survival.item()
        batch_log['c_index'] = c_index
        return batch_log
class Model(NetInterface):
    """Early-fusion model: concatenate all modalities, then autoencode and classify.

    Every modality is flattened to ``(batch, -1)`` and the results are
    concatenated along dim 1 into a single input. One encoder ``net_q`` and
    decoder ``net_p`` are trained to reconstruct it, and the encoder's code
    also feeds the classifier ``net_c``. ``--x_dim`` presumably has to be the
    total concatenated width (TODO confirm against ``Q_net``).
    """

    @classmethod
    def add_arguments(cls, parser):
        """Register this model's options on ``parser``; return ``(parser, set())``."""
        specs = (
            ('--net_N', int, 128, "Number of neurons in hidden layers"),
            ('--x_dim', int, 1000, "dimension of input features"),
            ('--z_dim', int, 100, "dimension of hidden variables"),
            ('--n_classes', int, 1, "number of nodes for classification"),
            ('--p_drop', float, 0.2, "probability of dropout"),
        )
        for flag, arg_type, default_value, help_text in specs:
            parser.add_argument(flag,
                                type=arg_type,
                                default=default_value,
                                help=help_text)
        return parser, set()

    def __init__(self, opt, logger):
        """Build the encoder, decoder and classifier with one optimizer/scheduler each.

        Args:
            opt: parsed options (net_N, x_dim, z_dim, n_classes, p_drop,
                modality, lr, log_time).
            logger: passed through to the parent constructor.
        """
        super().__init__(opt, logger)
        self.net_N = opt.net_N
        self.x_dim = opt.x_dim
        self.z_dim = opt.z_dim
        self.n_classes = opt.n_classes
        self.p_drop = opt.p_drop
        self.modality = opt.modality
        self.n_modality = len(self.modality)

        # networks, constructed in the order Q, P, C
        self.net_q = Q_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_p = P_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_c = C_net(self.net_N, self.z_dim, self.n_classes, self.p_drop)
        nets = [self.net_q, self.net_p, self.net_c]
        self._nets = nets

        # one optimizer per network, then one scheduler per optimizer
        optimizers = [
            self.optimizer(net.parameters(), lr=opt.lr,
                           **self.optimizer_params) for net in nets
        ]
        self.optimizer_q, self.optimizer_p, self.optimizer_c = optimizers
        self._optimizers = optimizers

        schedulers = [
            self.scheduler(optim, **self.scheduler_params)
            for optim in optimizers
        ]
        self.scheduler_q, self.scheduler_p, self.scheduler_c = schedulers
        self._schedulers = schedulers

        # bookkeeping: reconstruction + classification losses
        self.opt = opt
        self._metrics = ['loss_mse', 'loss_clf']
        if opt.log_time:
            self._metrics += ['t_recon', 't_clf']

        self.init_vars(add_path=True)

        for net in (self.net_q, self.net_p, self.net_c):
            self.init_weight(net)

    def __str__(self):
        return "Concat + Autoencoder"

    def _concat_inputs(self, batch):
        """Flatten each modality of ``batch``, move it to CUDA, and concatenate on dim 1."""
        pieces = []
        for name in self.modality:
            tensor = batch[name]
            pieces.append(tensor.view(tensor.shape[0], -1).cuda())
        return torch.cat(pieces, dim=1)

    def _train_on_batch(self, epoch, batch_idx, batch):
        """Run one reconstruction step and one classification step on ``batch``.

        Returns:
            dict: ``size``, ``loss_mse`` and ``loss_clf``, plus ``t_recon`` /
            ``t_clf`` wall-clock seconds when ``opt.log_time`` is set.
        """
        encoder, decoder, classifier = self.net_q, self.net_p, self.net_c
        opt_enc, opt_dec, opt_cls = (self.optimizer_q, self.optimizer_p,
                                     self.optimizer_c)
        for net in (encoder, decoder, classifier):
            net.train()

        X = self._concat_inputs(batch)
        y = batch['labels'].cuda()
        batch_log = {'size': X.shape[0]}

        def set_trainable(net, flag):
            for param in net.parameters():
                param.requires_grad = flag

        # Stage 1: reconstruction loss updates Q and P; C is frozen
        encoder.zero_grad()
        decoder.zero_grad()
        set_trainable(encoder, True)
        set_trainable(decoder, True)
        set_trainable(classifier, False)
        start = time()
        loss_mse = nn.functional.mse_loss(decoder(encoder(X)), X)
        loss_mse.backward()
        opt_enc.step()
        opt_dec.step()
        t_recon = time() - start
        batch_log['loss_mse'] = loss_mse.item()

        # Stage 2: cross-entropy loss updates Q and C; P is frozen
        encoder.zero_grad()
        classifier.zero_grad()
        set_trainable(encoder, True)
        set_trainable(classifier, True)
        set_trainable(decoder, False)
        start = time()
        pred = classifier(encoder(X))
        loss_clf = nn.CrossEntropyLoss()(pred, y)
        loss_clf.backward()
        opt_enc.step()
        opt_cls.step()
        t_clf = time() - start
        batch_log['loss_clf'] = loss_clf.item()

        if self.opt.log_time:
            batch_log['t_recon'] = t_recon
            batch_log['t_clf'] = t_clf
        return batch_log

    def _vali_on_batch(self, epoch, batch_idx, batch):
        """Evaluate the classification loss on ``batch`` without gradient tracking.

        Returns:
            dict: ``size``, ``loss_clf`` and ``loss`` (the same value).
        """
        self.net_q.eval()
        self.net_c.eval()

        X = self._concat_inputs(batch)
        y = batch['labels'].cuda()
        batch_log = {'size': X.shape[0]}
        with torch.no_grad():
            pred = self.net_c(self.net_q(X))

        loss_value = nn.CrossEntropyLoss()(pred, y).item()
        batch_log['loss_clf'] = loss_value
        batch_log['loss'] = loss_value
        return batch_log
class Model(NetInterface):
    """Single-modality classifier: encoder ``net_q`` followed by ``net_c``.

    Both networks are optimized jointly with a cross-entropy loss; there is
    no reconstruction decoder in this model.
    """

    @classmethod
    def add_arguments(cls, parser):
        """Register this model's hyper-parameters on ``parser``.

        Returns:
            ``(parser, set())`` -- the same parser and an empty set.
        """
        specs = (
            ('--net_N', {'type': int, 'default': 128,
                         'help': "Number of neurons in hidden layers"}),
            ('--x_dim', {'type': int, 'default': 1000,
                         'help': "dimension of input features"}),
            ('--z_dim', {'type': int, 'default': 100,
                         'help': "dimension of hidden variables"}),
            ('--n_classes', {'type': int, 'default': 1,
                             'help': "number of nodes for classification"}),
            ('--p_drop', {'type': float, 'default': 0.2,
                          'help': "probability of dropout"}),
            ('--clf_weights', {'type': float, 'nargs': '+', 'default': 0.7,
                               'help': "classification weight for each class"}),
        )
        for flag, kwargs in specs:
            parser.add_argument(flag, **kwargs)
        return parser, set()

    def __init__(self, opt, logger):
        super().__init__(opt, logger)
        # Hyper-parameters copied from the parsed options.
        self.net_N = opt.net_N
        self.x_dim = opt.x_dim
        self.z_dim = opt.z_dim
        self.n_classes = opt.n_classes
        self.p_drop = opt.p_drop
        # opt.modality is a list; this model only consumes its first entry.
        self.modality = opt.modality[0]
        self.clf_weights = opt.clf_weights

        # Encoder (input -> latent) and classifier head (latent -> logits).
        self.net_q = Q_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_c = C_net(self.net_N, self.z_dim, self.n_classes, self.p_drop)
        self._nets = [self.net_q, self.net_c]

        # One optimizer and one scheduler per network. self.optimizer,
        # self.scheduler and their *_params are set up by NetInterface.
        self._optimizers = [
            self.optimizer(net.parameters(), lr=opt.lr, **self.optimizer_params)
            for net in self._nets
        ]
        self.optimizer_q, self.optimizer_c = self._optimizers

        self._schedulers = [
            self.scheduler(optim, **self.scheduler_params)
            for optim in self._optimizers
        ]
        self.scheduler_q, self.scheduler_c = self._schedulers

        self.opt = opt
        # Only the classification loss (plus optional timing) is logged.
        self._metrics = ['loss_clf'] + (['t_clf'] if opt.log_time else [])

        self.init_vars(add_path=True)
        for net in (self.net_q, self.net_c):
            self.init_weight(net)

    def __str__(self):
        return "Autoencoder"

    def _train_on_batch(self, epoch, batch_idx, batch):
        """Run one optimization step of encoder + classifier on ``batch``.

        Returns:
            dict with ``'size'``, ``'loss_clf'`` and, when ``opt.log_time``
            is set, ``'t_clf'`` (wall-clock seconds of the step).
        """
        encoder, classifier = self.net_q, self.net_c
        encoder_opt, classifier_opt = self.optimizer_q, self.optimizer_c
        encoder.train()
        classifier.train()

        # Flatten every sample to a vector, e.g. (B, 1, 28, 28) -> (B, 784).
        inputs = batch[self.modality].cuda()
        inputs = inputs.view(inputs.shape[0], -1)
        targets = batch['labels'].cuda()
        batch_log = {'size': inputs.shape[0]}

        encoder.zero_grad()
        classifier.zero_grad()
        for net in (encoder, classifier):
            for param in net.parameters():
                param.requires_grad = True

        start = time()
        logits = classifier(encoder(inputs))
        # CrossEntropyLoss = LogSoftmax followed by the negative log-likelihood.
        loss_clf = nn.CrossEntropyLoss()(logits, targets)
        loss_clf.backward()
        encoder_opt.step()
        classifier_opt.step()
        elapsed = time() - start

        batch_log['loss_clf'] = loss_clf.item()
        if self.opt.log_time:
            batch_log['t_clf'] = elapsed
        return batch_log

    def _vali_on_batch(self, epoch, batch_idx, batch):
        """Compute the validation cross-entropy on ``batch`` without gradients.

        Returns:
            dict with ``'size'``, ``'loss'`` and ``'loss_clf'`` (same value).
        """
        self.net_q.eval()
        self.net_c.eval()
        inputs = batch[self.modality].cuda()
        inputs = inputs.view(inputs.shape[0], -1)
        targets = batch['labels'].cuda()
        batch_log = {'size': inputs.shape[0]}
        with torch.no_grad():
            logits = self.net_c(self.net_q(inputs))
        loss_value = nn.CrossEntropyLoss()(logits, targets).item()
        batch_log['loss'] = loss_value
        batch_log['loss_clf'] = loss_value
        return batch_log
class Model(NetInterface):
    """Autoencoder-regularized survival model for a single modality.

    ``net_q`` encodes the input into a latent code, ``net_p`` decodes it back
    (trained with an MSE reconstruction loss) and ``net_c`` maps the latent
    code to a prediction trained with ``neg_par_log_likelihood`` and scored
    with ``CIndex``.
    """

    @classmethod
    def add_arguments(cls, parser):
        """Register this model's hyper-parameters on ``parser``.

        Returns:
            ``(parser, set())`` -- the same parser and an empty set.
        """
        parser.add_argument('--net_N',
                            type=int,
                            default=128,
                            help="Number of neurons in hidden layers")
        parser.add_argument('--x_dim',
                            type=int,
                            default=1000,
                            help="dimension of input features")
        parser.add_argument('--z_dim',
                            type=int,
                            default=100,
                            help="dimension of hidden variables")
        parser.add_argument('--n_classes',
                            type=int,
                            default=1,
                            help="number of nodes for classification")
        parser.add_argument('--p_drop',
                            type=float,
                            default=0.2,
                            help="probability of dropout")
        # NOTE(review): with nargs='+' a user-supplied value becomes a list of
        # floats while the default is the scalar 0.7; the value is stored on
        # the model but not used by this class.
        parser.add_argument('--clf_weights',
                            type=float,
                            nargs='+',
                            default=0.7,
                            help="classification weight for each class")
        return parser, set()

    def __init__(self, opt, logger):
        super().__init__(opt, logger)
        self.net_N = opt.net_N
        self.x_dim = opt.x_dim
        self.z_dim = opt.z_dim
        self.n_classes = opt.n_classes
        self.p_drop = opt.p_drop
        self.modality = opt.modality
        self.clf_weights = opt.clf_weights

        # init networks: encoder Q, decoder P and prediction head C
        self.net_q = Q_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_p = P_net(self.net_N, self.x_dim, self.z_dim, self.p_drop)
        self.net_c = C_net(self.net_N, self.z_dim, self.n_classes, self.p_drop)
        self._nets = [self.net_q, self.net_p, self.net_c]

        # optimizers (one per network)
        # self.optimizer and self.optimizer_params have been initialized in netinterface
        self.optimizer_q = self.optimizer(self.net_q.parameters(),
                                          lr=opt.lr,
                                          **self.optimizer_params)
        self.optimizer_p = self.optimizer(self.net_p.parameters(),
                                          lr=opt.lr,
                                          **self.optimizer_params)
        self.optimizer_c = self.optimizer(self.net_c.parameters(),
                                          lr=opt.lr,
                                          **self.optimizer_params)
        self._optimizers = [
            self.optimizer_q, self.optimizer_p, self.optimizer_c
        ]

        # schedulers (one per optimizer)
        # self.scheduler and self.scheduler_params have been initialized in netinterface
        self.scheduler_q = self.scheduler(self.optimizer_q,
                                          **self.scheduler_params)
        self.scheduler_p = self.scheduler(self.optimizer_p,
                                          **self.scheduler_params)
        self.scheduler_c = self.scheduler(self.optimizer_c,
                                          **self.scheduler_params)
        self._schedulers = [
            self.scheduler_q, self.scheduler_p, self.scheduler_c
        ]

        # general
        self.opt = opt
        # log the reconstruction loss, the survival loss and the C-index
        self._metrics = ['loss_mse', 'loss_survival', 'c_index']
        if opt.log_time:
            self._metrics += ['t_recon', 't_survival']

        # init variables
        self.init_vars(add_path=True)

        # init weights
        self.init_weight(self.net_q)
        self.init_weight(self.net_p)
        self.init_weight(self.net_c)

    def __str__(self):
        s = "Autoencoder"
        return s

    def _train_on_batch(self, epoch, batch_idx, batch):
        """Run one two-stage training step on ``batch``.

        Stage 1 updates ``net_q`` and ``net_p`` with the MSE reconstruction
        loss. Stage 2 updates ``net_q`` and ``net_c`` with the survival loss
        and is skipped when ``batch['event']`` sums to zero.

        Args:
            epoch: current epoch index (not used here).
            batch_idx: index of the batch inside the epoch (not used here).
            batch: mapping with the input tensor under ``self.modality`` and
                ``'days'`` / ``'event'`` tensors.

        Returns:
            dict with ``'size'``, ``'loss_mse'``, ``'loss_survival'`` and, if
            stage 2 ran, ``'c_index'``. For a skipped batch ``'loss_survival'``
            is a NaN tensor. ``'t_recon'`` (and ``'t_survival'`` if stage 2
            ran) are added when ``opt.log_time`` is set.
        """
        net_q, net_p, net_c = self.net_q, self.net_p, self.net_c
        opt_q, opt_p, opt_c = self.optimizer_q, self.optimizer_p, self.optimizer_c
        net_q.train()
        net_p.train()
        net_c.train()

        X = batch[self.modality].cuda()
        survival_time = batch['days'].cuda()
        survival_event = batch['event'].cuda()
        batch_size = X.shape[0]
        batch_log = {'size': batch_size}

        # NOTE(review): timings below use time() without a CUDA synchronize,
        # so they may reflect host-side launch time rather than GPU compute.

        # Stage 1: train Q and P with reconstruction loss (C frozen)
        net_q.zero_grad()
        net_p.zero_grad()
        for p in net_q.parameters():
            p.requires_grad = True
        for p in net_p.parameters():
            p.requires_grad = True
        for p in net_c.parameters():
            p.requires_grad = False
        t0 = time()
        z = net_q(X)
        X_recon = net_p(z)
        loss_mse = nn.functional.mse_loss(X_recon, X)  # Mean square error
        loss_mse.backward()
        opt_q.step()
        opt_p.step()
        t_recon = time() - t0
        batch_log['loss_mse'] = loss_mse.item()
        if self.opt.log_time:
            # Stage 1 always runs, so its timing is logged even when stage 2
            # is skipped below.
            batch_log['t_recon'] = t_recon

        # Stage 2: train Q and C with the survival loss (P frozen).
        # Skip it when no instance in the batch has an event (presumably all
        # censored), matching the check in _vali_on_batch.
        if not survival_event.sum(0):
            batch_log['loss_survival'] = torch.Tensor([float('nan')])
            return batch_log
        net_q.zero_grad()
        net_c.zero_grad()
        for p in net_q.parameters():
            p.requires_grad = True
        for p in net_c.parameters():
            p.requires_grad = True
        for p in net_p.parameters():
            p.requires_grad = False
        t0 = time()
        z = net_q(X)
        pred = net_c(z)
        loss_survival = neg_par_log_likelihood(pred, survival_time,
                                               survival_event)
        loss_survival.backward()
        opt_q.step()
        opt_c.step()
        t_survival = time() - t0
        c_index = CIndex(pred, survival_event, survival_time)
        batch_log['loss_survival'] = loss_survival.item()
        batch_log['c_index'] = c_index
        if self.opt.log_time:
            batch_log['t_survival'] = t_survival
        return batch_log

    def _vali_on_batch(self, epoch, batch_idx, batch):
        """Compute the validation survival loss and C-index on ``batch``.

        Returns:
            dict with ``'size'`` and ``'loss_survival'``; when the batch has
            at least one event, also ``'loss'`` (same value) and ``'c_index'``.
            For a batch without events ``'loss_survival'`` is a NaN tensor.
        """
        self.net_q.eval()
        self.net_c.eval()
        # load the data
        X = batch[self.modality].cuda()
        survival_time = batch['days'].cuda()
        survival_event = batch['event'].cuda()
        batch_size = X.shape[0]
        batch_log = {'size': batch_size}
        # skip the batch if no instance has an event (event sums to zero)
        if not survival_event.sum(0):
            batch_log['loss_survival'] = torch.Tensor([float('nan')])
            return batch_log
        with torch.no_grad():
            z = self.net_q(X)
            pred = self.net_c(z)
        loss_survival = neg_par_log_likelihood(pred, survival_time,
                                               survival_event)
        c_index = CIndex(pred, survival_event, survival_time)
        batch_log['loss'] = loss_survival.item()
        batch_log['loss_survival'] = loss_survival.item()
        batch_log['c_index'] = c_index
        return batch_log