Example #1
import os
from collections import defaultdict

import torch
import torch.nn as nn

# AE and get_dataloader are imported from elsewhere in this project.


class BiomeAE:
    def __init__(self, args):
        if args.model in ["BiomeAEL0"]:
            self.mlp_type = "L0"
        else:
            self.mlp_type = None

        self.model_alias = args.model_alias
        self.model = args.model
        self.snap_loc = os.path.join(args.vis_dir, "snap.pt")

        #tl.configure("runs/ds.{}".format(model_alias))
        #tl.log_value(model_alias, 0)

        """ no stat file needed for now
        stat_alias = 'obj_DataStat+%s_%s' % (args.dataset_name, args.dataset_subset)
        stat_path = os.path.join(
            output_dir, '%s.pkl' % (stat_alias)
        )
        with open(stat_path,'rb') as sf:
            data_stats = pickle.load(sf)
        """

        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
        self.predictor = None

        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(args.seed)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


    def get_transformation(self):
        return None

    def loss_fn(self, recon_x, x, mean, log_var):
        if self.model in ["BiomeAE", "BiomeAESnip"]:
            mseloss = torch.nn.MSELoss()
            return torch.sqrt(mseloss(recon_x, x))
        elif self.model == "BiomeAEL0":
            # RMSE plus the L0 sparsity penalty of the predictor.
            mseloss = torch.nn.MSELoss()
            return torch.sqrt(mseloss(recon_x, x)) + self.predictor.regularization()
        elif self.model == "BiomeVAE":
            # Standard VAE objective: reconstruction (BCE) plus KL divergence.
            BCE = torch.nn.functional.binary_cross_entropy(
                recon_x.view(-1, 28 * 28), x.view(-1, 28 * 28), reduction='sum')
            KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
            return (BCE + KLD) / x.size(0)
        raise ValueError("Unknown model: %s" % self.model)

    def param_l0(self):
        return self.predictor.param_l0()

    def init_fit(self, X1_train, X2_train, y_train, X1_val, X2_val, y_val, args):
        self.train_loader = get_dataloader(X1_train, X2_train, y_train, args.batch_size)
        self.test_loader = get_dataloader(X1_val, X2_val, y_val, args.batch_size)

        self.predictor = AE(
            encoder_layer_sizes=[X1_train.shape[1]],
            latent_size=args.latent_size,
            decoder_layer_sizes=[X2_train.shape[1]],
            activation=args.activation,
            batch_norm=args.batch_norm,
            dropout=args.dropout,
            mlp_type=self.mlp_type,
            conditional=args.conditional,
            num_labels=10 if args.conditional else 0).to(self.device)
        self.optimizer = torch.optim.Adam(self.predictor.parameters(), lr=args.learning_rate)
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.8)

    def train(self, args):
        if args.contr:
            print("Loading from ", self.snap_loc)
            loaded_model_para = torch.load(self.snap_loc)
            self.predictor.load_state_dict(loaded_model_para)

        t = 0
        logs = defaultdict(list)
        iterations_per_epoch = len(self.train_loader.dataset) / args.batch_size
        num_iterations = int(iterations_per_epoch * args.epochs)
        for epoch in range(args.epochs):

            tracker_epoch = defaultdict(lambda: defaultdict(dict))

            for iteration, (x1, x2, y) in enumerate(self.train_loader):
                t += 1

                x1, x2, y = x1.to(self.device), x2.to(self.device), y.to(self.device)

                if args.conditional:
                    x2_hat, z, mean, log_var = self.predictor(x1, y)
                else:
                    x2_hat, z, mean, log_var = self.predictor(x1)

                for i, yi in enumerate(y):
                    sample_id = len(tracker_epoch)
                    tracker_epoch[sample_id]['x'] = z[i, 0].item()
                    tracker_epoch[sample_id]['y'] = z[i, 1].item()
                    tracker_epoch[sample_id]['label'] = yi.item()

                loss = self.loss_fn(x2_hat, x2, mean, log_var)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                # Decay the learning rate ten times over training; the scheduler
                # steps after the optimizer, and the divisor is guarded against
                # zero for very short runs.
                if (t + 1) % max(1, num_iterations // 10) == 0:
                    self.scheduler.step()

                # Enforce non-negative weights.
                if args.nonneg_weight:
                    for layer in self.predictor.modules():
                        if isinstance(layer, nn.Linear):
                            layer.weight.data.clamp_(0.0)


                logs['loss'].append(loss.item())

                if iteration % args.print_every == 0 or iteration == len(self.train_loader)-1:
                    print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:9.4f}".format(
                        epoch, args.epochs, iteration, len(self.train_loader)-1, loss.item()))
                    if args.model =="VAE":
                        if args.conditional:
                            c = torch.arange(0, 10).long().unsqueeze(1)
                            x = self.predictor.inference(n=c.size(0), c=c)
                        else:
                            x = self.predictor.inference(n=10)

        if not args.contr:
            print("Saving to ", self.snap_loc)
            torch.save(self.predictor.state_dict(), self.snap_loc)

    def fit(self, X1_train, X2_train, y_train, X1_val, X2_val, y_val, args):
        self.init_fit(X1_train, X2_train, y_train, X1_val, X2_val, y_val, args)
        self.train(args)

    def get_graph(self):
        """
        Return the layer nodes and connection weights of the network.
        """
        nodes = []
        weights = []
        for layer in self.predictor.modules():
            if isinstance(layer, nn.Linear):
                lin_layer = layer
                nodes.append([str(x) for x in range(lin_layer.in_features)])
                weights.append(lin_layer.weight.detach().cpu().numpy().T)
        nodes.append([str(x) for x in range(lin_layer.out_features)])  # last linear layer

        return (nodes, weights)

    def predict(self, X1_val, X2_val, y_val, args):
        # Batch test
        x1, x2, y = (torch.FloatTensor(X1_val).to(self.device),
                     torch.FloatTensor(X2_val).to(self.device),
                     torch.FloatTensor(y_val).to(self.device))
        if args.conditional:
            x2_hat, z, mean, log_var = self.predictor(x1, y)
        else:
            x2_hat, z, mean, log_var = self.predictor(x1)
        val_loss = self.loss_fn(x2_hat, x2, mean, log_var)
        print("val_loss: {:9.4f}".format(val_loss.item()))
        return x2_hat.detach().cpu().numpy()

    def transform(self, X1_val, X2_val, y_val, args):
        x1, x2, y = (torch.FloatTensor(X1_val).to(self.device),
                     torch.FloatTensor(X2_val).to(self.device),
                     torch.FloatTensor(y_val).to(self.device))
        if args.conditional:
            x2_hat, z, mean, log_var = self.predictor(x1, y)
        else:
            x2_hat, z, mean, log_var = self.predictor(x1)

        return z.detach().cpu().numpy()

    def get_influence_matrix(self):
        return self.predictor.get_influence_matrix()
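
A minimal usage sketch for the class above. It assumes load_data is the project's own loader; every argument value shown here is a hypothetical placeholder.

from argparse import Namespace

# Hypothetical configuration; the field names mirror the attributes BiomeAE reads.
args = Namespace(
    model="BiomeAE", model_alias="biomeae_demo", vis_dir="./vis",
    gpu_num="0", seed=42, latent_size=2, activation="relu",
    batch_norm=False, dropout=0.0, conditional=False, contr=False,
    nonneg_weight=False, batch_size=32, learning_rate=1e-3,
    epochs=10, print_every=100)

# load_data and the pickle path are assumptions about the surrounding project.
(X1_train, X2_train, y_train), (X1_val, X2_val, y_val) = load_data("ibd_raw.pkl")

model = BiomeAE(args)
model.fit(X1_train, X2_train, y_train, X1_val, X2_val, y_val, args)
X2_pred = model.predict(X1_val, X2_val, y_val, args)   # reconstructed X2
Z = model.transform(X1_val, X2_val, y_val, args)       # latent codes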
Example #2
import os
import time
from collections import defaultdict

import torch

# tl (a tensorboard logger), logger, DATA_ROOT, load_data, get_dataloader,
# AE, and VAE are imported from elsewhere in this project.


def main(args):

    model_alias = 'DeepBiome_%s+%s_%s+fea1_%s+fea2_%s+bs_%s+%s' % (
        args.model,
        args.dataset_name, args.data_type,
        args.fea1, args.fea2,
        args.batch_size,
        args.extra)

    tl.configure("runs/ds.{}".format(model_alias))
    tl.log_value(model_alias, 0)

    """ no stat file needed for now
    stat_alias = 'obj_DataStat+%s_%s' % (args.dataset_name, args.dataset_subset)
    stat_path = os.path.join(
        output_dir, '%s.pkl' % (stat_alias)
    )
    with open(stat_path,'rb') as sf:
        data_stats = pickle.load(sf)
    """

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num

    data_path = os.path.join(DATA_ROOT, "ibd_{}.pkl".format(args.data_type))


    logger.info("Initializing train dataset")

    # load data
    print('==> loading data')
    print()
    (X1_train, X2_train,  y_train), (X1_val, X2_val, y_val) = load_data(data_path)
    train_loader = get_dataloader(X1_train, X2_train, y_train, args.batch_size)
    test_loader = get_dataloader(X1_val, X2_val, y_val, args.batch_size)

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ts = time.time()

    def loss_fn(model, recon_x, x, mean, log_var):
        if model == "AE":
            mseloss = torch.nn.MSELoss()
            return torch.sqrt(mseloss(recon_x, x))
        elif model == "VAE":
            # Standard VAE objective: reconstruction (BCE) plus KL divergence.
            BCE = torch.nn.functional.binary_cross_entropy(
                recon_x.view(-1, 28 * 28), x.view(-1, 28 * 28), reduction='sum')
            KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
            return (BCE + KLD) / x.size(0)
        raise ValueError("Unknown model: %s" % model)

    if args.model == "AE":
        predictor = AE(
            encoder_layer_sizes=[X1_train.shape[1]],
            latent_size=args.latent_size,
            decoder_layer_sizes=[X2_train.shape[1]],
            activation=args.activation,
            batch_norm=args.batch_norm,
            dropout=args.dropout,
            conditional=args.conditional,
            num_labels=10 if args.conditional else 0).to(device)
    else:
        predictor = VAE(
            encoder_layer_sizes=args.encoder_layer_sizes,
            latent_size=args.latent_size,
            decoder_layer_sizes=args.decoder_layer_sizes,
            activation=args.activation,
            batch_norm=args.batch_norm,
            dropout=args.dropout,
            conditional=args.conditional,
            num_labels=10 if args.conditional else 0).to(device)

    optimizer = torch.optim.Adam(predictor.parameters(), lr=args.learning_rate)

    logs = defaultdict(list)

    for epoch in range(args.epochs):

        tracker_epoch = defaultdict(lambda: defaultdict(dict))

        for iteration, (x1, x2, y) in enumerate(train_loader):

            x1, x2, y = x1.to(device), x2.to(device), y.to(device)

            if args.conditional:
                x2_hat, z, mean, log_var = predictor(x1, y)
            else:
                x2_hat, z, mean, log_var = predictor(x1)

            for i, yi in enumerate(y):
                sample_id = len(tracker_epoch)
                tracker_epoch[sample_id]['x'] = z[i, 0].item()
                tracker_epoch[sample_id]['y'] = z[i, 1].item()
                tracker_epoch[sample_id]['label'] = yi.item()

            loss = loss_fn(args.model, x2_hat, x2, mean, log_var)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            logs['loss'].append(loss.item())

            if iteration % args.print_every == 0 or iteration == len(train_loader)-1:
                print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:9.4f}".format(
                    epoch, args.epochs, iteration, len(train_loader)-1, loss.item()))
                if args.model =="VAE":
                    if args.conditional:
                        c = torch.arange(0, 10).long().unsqueeze(1)
                        x = predictor.inference(n=c.size(0), c=c)
                    else:
                        x = predictor.inference(n=10)
                """
                plt.figure()
                plt.figure(figsize=(5, 10))
                for p in range(10):
                    plt.subplot(5, 2, p+1)
                    if args.conditional:
                        plt.text(
                            0, 0, "c={:d}".format(c[p].item()), color='black',
                            backgroundcolor='white', fontsize=8)
                    plt.imshow(x[p].view(28, 28).cpu().data.numpy())
                    plt.axis('off')

                if not os.path.exists(os.path.join(args.fig_root, str(ts))):
                    if not(os.path.exists(os.path.join(args.fig_root))):
                        os.mkdir(os.path.join(args.fig_root))
                    os.mkdir(os.path.join(args.fig_root, str(ts)))

                plt.savefig(
                    os.path.join(args.fig_root, str(ts),
                                 "E{:d}I{:d}.png".format(epoch, iteration)),
                    dpi=300)
                plt.clf()
                plt.close('all')
                """
        # Batch test
        x1, x2, y = (torch.FloatTensor(X1_val).to(device),
                     torch.FloatTensor(X2_val).to(device),
                     torch.FloatTensor(y_val).to(device))
        if args.conditional:
            x2_hat, z, mean, log_var = predictor(x1, y)
        else:
            x2_hat, z, mean, log_var = predictor(x1)
        val_loss = loss_fn(args.model, x2_hat, x2, mean, log_var)
        print("val_loss: {:9.4f}".format(val_loss.item()))
        """