示例#1
0
    def __init__(self, num_epochs: int = 500, cuda: bool = False, models_path: str = "./models_standard_128"):
        """Build the loaders, networks, optimizers and losses used for training.

        Args:
            num_epochs: Stored as ``self.num_epochs``.
            cuda: Stored as ``self.cuda``; nothing in this constructor moves
                modules or tensors to a GPU.
            models_path: Stored as ``self.models_path``.
        """
        # Caller-supplied settings.
        self.num_epochs = num_epochs
        self.cuda = cuda
        self.models_path = models_path

        # Fixed dataset dimensions and loader settings.
        self.features = 16
        self.instances = 32
        self.classes = 2
        self.z_size = 100
        self.batch_size = 1
        self.workers = 5

        # Reporting and checkpoint cadence.
        self.log_step_print = 100
        self.save_period = 10

        # Adam hyper-parameters shared by both optimizers.
        self.lr = 0.0002
        self.beta1 = 0.5
        self.beta2 = 0.999

        self.graph_builder = GraphBuilder()
        self.lambdas = LambdaFeaturesCollector(self.features, self.instances)
        self.metas = MetaFeaturesCollector(self.features, self.instances)

        train_dir = f"../loader/datasets/dprocessed_{self.features}_{self.instances}_{self.classes}/"
        self.data_loader = get_loader(
            train_dir,
            self.features, self.instances, self.classes,
            self.metas, self.lambdas,
            self.batch_size, self.workers,
        )
        # Same dimensions and worker count as the training loader; 228 is
        # passed in the position where the training loader gets batch_size.
        self.test_loader = get_loader(
            "../loader/datasets/dtest32/",
            self.features, self.instances, self.classes,
            self.metas, self.lambdas,
            228, self.workers,
            train_meta=False,
        )

        self.generator = Generator(
            self.features, self.instances, self.classes,
            self.metas.getLength(), self.z_size,
        )
        self.discriminator = Discriminator(
            self.features, self.instances, self.classes,
            self.metas.getLength(), self.lambdas.getLength(),
        )

        self.g_optimizer = optim.Adam(
            self.generator.parameters(),
            lr=self.lr, betas=[self.beta1, self.beta2],
        )
        self.d_optimizer = optim.Adam(
            self.discriminator.parameters(),
            lr=self.lr, betas=[self.beta1, self.beta2],
        )

        self.cross_entropy = BCEWithLogitsLoss()
        self.mse = MSELoss()
        meta_list.append(metas.getShort(data.cpu().detach().numpy()))
    result = torch.stack(meta_list)
    return to_variable(result.view((result.size(0), result.size(1), 1, 1)))


def to_variable(x):
    """Wrap ``x`` in ``Variable`` and return the result."""
    wrapped = Variable(x)
    return wrapped


# Experiment configuration.
exp_num = 3
datasize = 64
z_size = 100
batch_size = 1
workers = 5

# Feature collectors shared by both loaders and by the two networks.
lambdas = LambdaFeaturesCollector(16, datasize)
metas = MetaFeaturesCollector(16, datasize)

# Training loader, then the test loader (the latter with train_meta=False).
dataloader = get_loader(
    f"../processed_data/processed_16_64_2/",
    16, datasize, 2,
    metas, lambdas,
    batch_size, workers,
)
datatest = get_loader(
    f"../processed_data/test/",
    16, datasize, 2,
    metas, lambdas,
    batch_size, workers,
    train_meta=False,
)

# Discriminator is constructed before the generator, as in the original order.
discriminator = Discriminator(
    16, datasize, 2,
    metas.getLength(), lambdas.getLength(),
)
generator = Generator(16, datasize, 2, metas.getLength(), z_size)
示例#3
0
class Trainer:
    def __init__(self, num_epochs: int = 500, cuda: bool = False, models_path: str = "./models_standard_128"):
        """Build the loaders, networks, optimizers and losses used for training.

        Args:
            num_epochs: Number of epochs ``train`` iterates over.
            cuda: Stored as ``self.cuda``; nothing in this constructor moves
                modules or tensors to a GPU.
            models_path: Directory ``train`` writes checkpoints and loss
                histories into.
        """
        # Caller-supplied settings.
        self.num_epochs = num_epochs
        self.cuda = cuda
        self.models_path = models_path

        # Fixed dataset dimensions and loader settings.
        self.features = 16
        self.instances = 32
        self.classes = 2
        self.z_size = 100
        self.batch_size = 1
        self.workers = 5

        # Reporting and checkpoint cadence.
        self.log_step_print = 100
        self.save_period = 10

        # Adam hyper-parameters shared by both optimizers.
        self.lr = 0.0002
        self.beta1 = 0.5
        self.beta2 = 0.999

        self.graph_builder = GraphBuilder()
        self.lambdas = LambdaFeaturesCollector(self.features, self.instances)
        self.metas = MetaFeaturesCollector(self.features, self.instances)

        train_dir = f"../loader/datasets/dprocessed_{self.features}_{self.instances}_{self.classes}/"
        self.data_loader = get_loader(
            train_dir,
            self.features, self.instances, self.classes,
            self.metas, self.lambdas,
            self.batch_size, self.workers,
        )
        # Same dimensions and worker count as the training loader; 228 is
        # passed in the position where the training loader gets batch_size.
        self.test_loader = get_loader(
            "../loader/datasets/dtest32/",
            self.features, self.instances, self.classes,
            self.metas, self.lambdas,
            228, self.workers,
            train_meta=False,
        )

        self.generator = Generator(
            self.features, self.instances, self.classes,
            self.metas.getLength(), self.z_size,
        )
        self.discriminator = Discriminator(
            self.features, self.instances, self.classes,
            self.metas.getLength(), self.lambdas.getLength(),
        )

        self.g_optimizer = optim.Adam(
            self.generator.parameters(),
            lr=self.lr, betas=[self.beta1, self.beta2],
        )
        self.d_optimizer = optim.Adam(
            self.discriminator.parameters(),
            lr=self.lr, betas=[self.beta1, self.beta2],
        )

        self.cross_entropy = BCEWithLogitsLoss()
        self.mse = MSELoss()

    def getDistance(self, x: torch.Tensor, y: torch.Tensor) -> [float]:
        x_in = np.squeeze(x.cpu().detach().numpy())
        y_in = np.squeeze(y.cpu().detach().numpy())
        results = []
        for (xx, yy) in zip(x_in, y_in):
            try:
                V = np.cov(np.array([xx, yy]).T)
                V[np.diag_indices_from(V)] += 0.1
                IV = np.linalg.inv(V)
                D = mahalanobis(xx, yy, IV)
            except:
                D = 0.0
            results.append(D)
        return results

    def getMeta(self, data_in: torch.Tensor):
        meta_list = []
        for data in data_in:
            meta_list.append(self.metas.getShort(data.cpu().detach().numpy()))
        result = torch.stack(meta_list)
        return Variable(result.view((result.size(0), result.size(1), 1, 1)))

    def getLambda(self, data_in: torch.Tensor):
        lamba_list = []
        for data in data_in:
            lamba_list.append(self.lambdas.get(data.cpu().detach().numpy()))
        result = torch.stack(lamba_list)
        return Variable(result)

    def get_d_real(self, dataset, metas, lambdas, zeros):
        graph1, graph2 = self.graph_builder.build_complete_graph(dataset)
        real_outputs = self.discriminator(graph1, graph2, metas)
        #real_outputs = self.discriminator(dataset, metas)

        d_real_labels_loss = self.mse(real_outputs[1:], lambdas)

        d_real_rf_loss = self.mse(real_outputs[:1], zeros)  #
        return d_real_labels_loss + 0.7 * d_real_rf_loss, d_real_labels_loss, d_real_rf_loss

    def get_d_fake(self, dataset, noise, metas, ones):
        fake_data = self.generator(noise, metas)
        fake_data_metas = self.getMeta(fake_data)

        graph1, graph2 = self.graph_builder.build_complete_graph(dataset)
        fake_outputs = self.discriminator(graph1, graph2, fake_data_metas)
        #fake_outputs = self.discriminator(fake_data, fake_data_metas)
        fake_lambdas = self.getLambda(fake_data).squeeze()
        d_fake_labels_loss = self.cross_entropy(fake_outputs[1:], fake_lambdas)
        d_fake_rf_loss = self.mse(fake_outputs[:1], ones)
        return 0.7 * d_fake_rf_loss + 0.6 * d_fake_labels_loss, d_fake_labels_loss, d_fake_rf_loss

    def train(self):
        total_steps = len(self.data_loader)
        g_loss_epochs = []
        d_loss_epochs = []
        for epoch in range(self.num_epochs):
            loss = []
            max_len = len(self.data_loader)
            g_loss_epoch1 = []
            d_loss_epoch1 = []
            for i, data in enumerate(self.data_loader):
                dataset = Variable(data[0])
                metas = Variable(data[1])
                lambdas = Variable(data[2]).squeeze()
                batch_size = data[0].size(0)
                noise = torch.randn(batch_size, self.z_size)
                noise = noise.view((noise.size(0), noise.size(1), 1, 1))
                noise = Variable(noise)
                zeros = torch.zeros([batch_size, 1], dtype=torch.float32)
                zeros = Variable(zeros)
                ones = torch.ones([batch_size, 1], dtype=torch.float32)
                ones = Variable(ones)

                d_real_loss, d_real_labels_loss, d_real_rf_loss = \
                    self.get_d_real(dataset, metas, lambdas, zeros)

                d_fake_loss, d_fake_labels_loss, d_fake_rf_loss = \
                    self.get_d_fake(dataset, noise, metas, ones)

                d_loss = d_real_loss + 0.8 * d_fake_loss
                self.generator.zero_grad()
                self.discriminator.zero_grad()
                d_loss.backward()
                self.d_optimizer.step()

                noise = torch.randn(batch_size, self.z_size)
                noise = noise.view(noise.size(0), noise.size(1), 1, 1)
                noise = Variable(noise)
                fake_data = self.generator(noise, metas)

                graph1, graph2 = self.graph_builder.build_complete_graph(dataset)
                fake_outputs = self.discriminator(graph1, graph2, metas)
                #fake_outputs = self.discriminator(fake_data, metas)
                g_fake_rf_loss = self.mse(fake_outputs[:1], zeros)
                fake_metas = self.getMeta(fake_data)
                g_fake_meta_loss = self.mse(fake_metas, metas)
                g_loss = 0.7 * g_fake_rf_loss + g_fake_meta_loss
                g_loss_epoch1.append(g_loss)
                d_loss_epoch1.append(d_loss)
                # minimize log(1 - D(G(z)))
                self.generator.zero_grad()
                self.discriminator.zero_grad()
                g_loss.backward()
                self.g_optimizer.step()

                if (i + 1) % self.log_step_print == 0:
                    print((
                        f'[{datetime.now()}] Epoch[{epoch}/{self.num_epochs}], Step[{i}/{total_steps}],\n'
                        f' D_losses: [{d_real_rf_loss}|{d_real_labels_loss}|{d_fake_rf_loss}|{d_fake_labels_loss}],\n'
                        f'G_losses:[{g_fake_rf_loss}|{g_fake_meta_loss}]'
                    ))

                if i == total_steps - 1:
                    print("Intermediate result - ")
                    print((
                        f'[{datetime.now()}] Epoch[{epoch}/{self.num_epochs}], Step[{i}/{total_steps}],\n'
                        f' D_losses: [{d_real_rf_loss}|{d_real_labels_loss}|{d_fake_rf_loss}|{d_fake_labels_loss}],\n'
                        f'G_losses:[{g_fake_rf_loss}|{g_fake_meta_loss}]'
                    ))
            d_loss_epochs.append((sum(d_loss_epoch1)) / total_steps)
            g_loss_epochs.append((sum(g_loss_epoch1)) / total_steps)
            # saving
            if (epoch + 1) % self.save_period == 0:
                done_data_str_path = Path(self.models_path)
                done_data_str_path.mkdir(parents=True, exist_ok=True)
                g_path = os.path.join(self.models_path,
                                      f'generator-{self.features}_{self.instances}_{self.classes}-{epoch + 1}.pkl')
                d_path = os.path.join(self.models_path,
                                      f'discriminator-{self.features}_{self.instances}_{self.classes}-{epoch + 1}.pkl')
                torch.save(self.generator.state_dict(), g_path)
                torch.save(self.discriminator.state_dict(), d_path)

        with open(os.path.join(self.models_path, 'g_loss.pickle'), 'wb') as g_loss_file:
            pickle.dump(g_loss_epochs, g_loss_file)
        with open(os.path.join(self.models_path, 'd_loss.pickle'), 'wb') as d_loss_file:
            pickle.dump(d_loss_epochs, d_loss_file)
示例#4
0
    def __init__(self,
                 num_epochs: int = 500,
                 cuda: bool = True,
                 continue_from: int = 0):
        """Build loaders, networks, optimizers and losses, optionally resuming.

        Args:
            num_epochs: Stored as ``self.num_epochs``.
            cuda: When true, both networks and both loss modules are moved
                to the GPU with ``.cuda()``.
            continue_from: ``0`` starts from freshly constructed networks.
                Any other value is the checkpoint number whose generator and
                discriminator state dicts are loaded from ``self.models_path``.
        """
        self.features = 16
        self.instances = 64
        self.classes = 2
        self.z_size = 100
        self.batch_size = 100
        self.workers = 5
        self.num_epochs = num_epochs
        self.cuda = cuda
        self.log_step = 10
        self.log_step_print = 50
        self.save_period = 5
        self.continue_from = continue_from

        self.models_path = "./models_grid"

        self.lambdas = LambdaFeaturesCollector(self.features, self.instances)
        self.metas = MetaFeaturesCollector(self.features, self.instances)
        self.data_loader = get_loader(
            f"../processed_data/processed_{self.features}_{self.instances}_{self.classes}/",
            self.features, self.instances, self.classes, self.metas,
            self.lambdas, self.batch_size, self.workers)
        # 228 is passed in the position where the training loader gets
        # batch_size.
        self.test_loader = get_loader(f"../processed_data/test/",
                                      16,
                                      64,
                                      2,
                                      self.metas,
                                      self.lambdas,
                                      228,
                                      5,
                                      train_meta=False)

        # Both networks are built the same way whether or not we resume.
        self.generator = Generator(self.features,
                                   self.instances, self.classes,
                                   self.metas.getLength(), self.z_size)
        self.discriminator = Discriminator(self.features, self.instances,
                                           self.classes,
                                           self.metas.getLength(),
                                           self.lambdas.getLength())

        if continue_from != 0:
            tag = f'{self.features}_{self.instances}_{self.classes}-{continue_from}'
            for prefix, network in (('generator', self.generator),
                                    ('discriminator', self.discriminator)):
                # Load onto the CPU so a checkpoint written on a GPU machine
                # can be restored on a CPU-only one; ``.cuda()`` below moves
                # the weights when requested.
                state = torch.load(f'{self.models_path}/{prefix}-{tag}.pkl',
                                   map_location='cpu')
                network.load_state_dict(state)
                # The networks are about to be trained further, so keep them
                # in training mode rather than switching to eval().
                network.train()

        if self.cuda:
            self.generator.cuda()
            self.discriminator.cuda()

        self.lr = 0.0002
        self.beta1 = 0.5
        self.beta2 = 0.999

        self.g_optimizer = optim.Adam(self.generator.parameters(), self.lr,
                                      [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(self.discriminator.parameters(), self.lr,
                                      [self.beta1, self.beta2])

        self.cross_entropy = BCEWithLogitsLoss()
        self.mse = MSELoss()
        if self.cuda:
            self.cross_entropy.cuda()
            self.mse.cuda()
示例#5
0
class Trainer:
    def __init__(self,
                 num_epochs: int = 500,
                 cuda: bool = True,
                 continue_from: int = 0):
        """Build loaders, networks, optimizers and losses, optionally resuming.

        Args:
            num_epochs: Epoch index at which ``train`` stops.
            cuda: When true, both networks and both loss modules are moved
                to the GPU with ``.cuda()``.
            continue_from: ``0`` starts from freshly constructed networks.
                Any other value is the checkpoint number whose generator and
                discriminator state dicts are loaded from
                ``self.models_path``; ``train`` then starts at that epoch.
        """
        self.features = 16
        self.instances = 64
        self.classes = 2
        self.z_size = 100
        self.batch_size = 100
        self.workers = 5
        self.num_epochs = num_epochs
        self.cuda = cuda
        self.log_step = 10
        self.log_step_print = 50
        self.save_period = 5
        self.continue_from = continue_from

        self.models_path = "./models_grid"

        self.lambdas = LambdaFeaturesCollector(self.features, self.instances)
        self.metas = MetaFeaturesCollector(self.features, self.instances)
        self.data_loader = get_loader(
            f"../processed_data/processed_{self.features}_{self.instances}_{self.classes}/",
            self.features, self.instances, self.classes, self.metas,
            self.lambdas, self.batch_size, self.workers)
        # 228 is passed in the position where the training loader gets
        # batch_size.
        self.test_loader = get_loader(f"../processed_data/test/",
                                      16,
                                      64,
                                      2,
                                      self.metas,
                                      self.lambdas,
                                      228,
                                      5,
                                      train_meta=False)

        # Both networks are built the same way whether or not we resume.
        self.generator = Generator(self.features,
                                   self.instances, self.classes,
                                   self.metas.getLength(), self.z_size)
        self.discriminator = Discriminator(self.features, self.instances,
                                           self.classes,
                                           self.metas.getLength(),
                                           self.lambdas.getLength())

        if continue_from != 0:
            tag = f'{self.features}_{self.instances}_{self.classes}-{continue_from}'
            for prefix, network in (('generator', self.generator),
                                    ('discriminator', self.discriminator)):
                # Load onto the CPU so a checkpoint written on a GPU machine
                # can be restored on a CPU-only one; ``.cuda()`` below moves
                # the weights when requested.
                state = torch.load(f'{self.models_path}/{prefix}-{tag}.pkl',
                                   map_location='cpu')
                network.load_state_dict(state)
                # The networks are about to be trained further, so keep them
                # in training mode rather than switching to eval().
                network.train()

        if self.cuda:
            self.generator.cuda()
            self.discriminator.cuda()

        self.lr = 0.0002
        self.beta1 = 0.5
        self.beta2 = 0.999

        self.g_optimizer = optim.Adam(self.generator.parameters(), self.lr,
                                      [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(self.discriminator.parameters(), self.lr,
                                      [self.beta1, self.beta2])

        self.cross_entropy = BCEWithLogitsLoss()
        self.mse = MSELoss()
        if self.cuda:
            self.cross_entropy.cuda()
            self.mse.cuda()

    def to_variable(self, x):
        if self.cuda:
            x = x.cuda()
        return Variable(x)

    def getDistance(self, x: torch.Tensor, y: torch.Tensor) -> [float]:
        x_in = np.squeeze(x.cpu().detach().numpy())
        y_in = np.squeeze(y.cpu().detach().numpy())
        results = []
        for (xx, yy) in zip(x_in, y_in):
            try:
                V = np.cov(np.array([xx, yy]).T)
                V[np.diag_indices_from(V)] += 0.1
                IV = np.linalg.inv(V)
                D = mahalanobis(xx, yy, IV)
            except:
                D = 0.0
            results.append(D)
        return results

    def getMeta(self, data_in: torch.Tensor):
        meta_list = []
        for data in data_in:
            meta_list.append(self.metas.getShort(data.cpu().detach().numpy()))
        result = torch.stack(meta_list)
        return self.to_variable(
            result.view((result.size(0), result.size(1), 1, 1)))

    def getLambda(self, data_in: torch.Tensor):
        lamba_list = []
        for data in data_in:
            lamba_list.append(self.lambdas.get(data.cpu().detach().numpy()))
        result = torch.stack(lamba_list)
        return self.to_variable(result)

    def train(self):
        total_steps = len(self.data_loader)
        logging.info(f'Starting training...')
        for epoch in range(self.continue_from, self.num_epochs):
            loss = []
            q = 0
            for i, data in enumerate(self.test_loader):
                q += 1
                dataset = self.to_variable(data[0])
                metas = self.to_variable(data[1])
                lambdas = self.to_variable(data[2])
                real_outputs = self.discriminator(dataset, metas)
                d_real_labels_loss = self.mse(real_outputs[:, 1:], lambdas)
                loss.append(d_real_labels_loss.cpu().detach().numpy())
            logging.info(f'{epoch}d:{np.mean(loss)}')
            results = []
            q = 0
            for i, data in enumerate(self.test_loader):
                q += 1
                metas = self.to_variable(data[1])
                batch_size = data[0].size(0)
                noise = torch.randn(batch_size, 100)
                noise = noise.view((noise.size(0), noise.size(1), 1, 1))
                noise = self.to_variable(noise)

                fake_data = self.generator(noise, metas)
                fake_metas = self.getMeta(fake_data)
                results.extend(self.mse(fake_metas, metas))
            logging.info(f'{epoch}g:{np.mean(np.array(results))}')

            q = 0
            for i, data in enumerate(self.data_loader):
                q += 1
                dataset = self.to_variable(data[0])
                metas = self.to_variable(data[1])
                lambdas = self.to_variable(data[2])
                batch_size = data[0].size(0)
                noise = torch.randn(batch_size, self.z_size)
                noise = noise.view((noise.size(0), noise.size(1), 1, 1))
                noise = self.to_variable(noise)
                zeros = torch.zeros([batch_size, 1], dtype=torch.float32)
                zeros = self.to_variable(zeros)
                ones = torch.ones([batch_size, 1], dtype=torch.float32)
                ones = self.to_variable(ones)

                # Get D on real
                real_outputs = self.discriminator(dataset, metas)
                d_real_labels_loss = self.mse(real_outputs[:, 1:], lambdas)
                d_real_rf_loss = self.mse(real_outputs[:, :1], zeros)
                d_real_loss = d_real_labels_loss + 0.7 * d_real_rf_loss

                # Get D on fake
                fake_data = self.generator(noise, metas)
                fake_data_metas = self.getMeta(fake_data)
                fake_outputs = self.discriminator(fake_data, fake_data_metas)
                fake_lambdas = self.getLambda(fake_data)
                d_fake_labels_loss = self.mse(fake_outputs[:, 1:],
                                              fake_lambdas)
                d_fake_rf_loss = self.mse(fake_outputs[:, :1], ones)
                d_fake_loss = 0.7 * d_fake_rf_loss + 0.6 * d_fake_labels_loss

                # Train D
                d_loss = d_real_loss + 0.8 * d_fake_loss
                self.generator.zero_grad()
                self.discriminator.zero_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Get D on fake
                noise = torch.randn(batch_size, self.z_size)
                noise = noise.view(noise.size(0), noise.size(1), 1, 1)
                noise = self.to_variable(noise)
                fake_data = self.generator(noise, metas)
                fake_outputs = self.discriminator(fake_data, metas)
                g_fake_rf_loss = self.mse(fake_outputs[:, :1], zeros)
                fake_metas = self.getMeta(fake_data)
                g_fake_meta_loss = self.mse(fake_metas, metas)
                g_loss = 0.7 * g_fake_rf_loss + g_fake_meta_loss

                # Train G
                self.generator.zero_grad()
                self.discriminator.zero_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # logging
                #if (q + 1) % self.log_step == 0:
                log = (
                    f'[[{epoch},{i}],[{d_real_rf_loss},{d_real_labels_loss},{d_fake_rf_loss},{d_fake_labels_loss}],[{g_fake_rf_loss},{g_fake_meta_loss}]]'
                )
                logging.info(log)
                #if (q + 1) % self.log_step_print == 0:
                print((
                    f'[{datetime.now()}] Epoch[{epoch}/{self.num_epochs}], Step[{q}/{total_steps}],'
                    f' D_losses: [{d_real_rf_loss}|{d_real_labels_loss}|{d_fake_rf_loss}|{d_fake_labels_loss}], '
                    f'G_losses:[{g_fake_rf_loss}|{g_fake_meta_loss}]'))

            # saving
            if (epoch + 1) % self.save_period == 0:
                done_data_str_path = Path(self.models_path)
                done_data_str_path.mkdir(parents=True, exist_ok=True)
                g_path = os.path.join(
                    self.models_path,
                    f'generator-{self.features}_{self.instances}_{self.classes}-{epoch + 1}.pkl'
                )
                d_path = os.path.join(
                    self.models_path,
                    f'discriminator-{self.features}_{self.instances}_{self.classes}-{epoch + 1}.pkl'
                )
                torch.save(self.generator.state_dict(), g_path)
                torch.save(self.discriminator.state_dict(), d_path)
示例#6
0
from feature_extraction.MetaFeaturesCollector import MetaFeaturesCollector
from sklearn.neighbors import KNeighborsClassifier
import os
import numpy as np
import time

# Set before the script body runs. NOTE(review): presumably a workaround for
# a duplicate OpenMP runtime error — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK']='True'

if __name__ == '__main__':
    # Seed NumPy's RNG from the wall clock, so every run differs.
    np.random.seed(int(time.time()))
    datasize = 64
    z_size = 100
    batch_size = 1
    workers = 5
    lambdas = LambdaFeaturesCollector(16, 64)
    metas = MetaFeaturesCollector(16, 64)
    # Training and test loaders share the collectors; the test loader is
    # built with train_meta=False.
    dataloader = get_loader(f"../processed_data/processed_16_64_2/", 16, 64, 2, metas, lambdas, batch_size, workers)
    datatest = get_loader(f"../processed_data/test/", 16, 64, 2, metas, lambdas, batch_size, workers, train_meta=False)

    # Flatten every training batch's meta-features and lambda labels into
    # plain Python lists, one entry per batch.
    meta_list = []
    lambdas_list = []
    for i, (data, meta, lambda_l) in enumerate(dataloader):
        meta_o = meta[:, :].numpy()
        meta_o = meta_o.ravel()
        meta_o = meta_o.tolist()
        meta_list.append(meta_o)
        # Lambda values are cast to int before flattening.
        lambdas_o = lambda_l[:, :].numpy().astype(int).ravel().tolist()
        lambdas_list.append(lambdas_o)

    # Accumulators for the test split; nothing in the visible code fills them.
    meta_list_test = []
    lambdas_list_test = []
示例#7
0
    results = []
    for (xx, yy) in zip(x_in, y_in):
        try:
            V = np.cov(np.array([xx, yy]).T)
            V[np.diag_indices_from(V)] += 0.1
            IV = np.linalg.inv(V)
            D = mahalanobis(xx, yy, IV)
        except:
            D = 0.0
        results.append(D)
    return results


if __name__ == '__main__':
    exp_num = 3
    metaCollector = MetaFeaturesCollector(16, 64)
    metaCollector.train(f"../processed_data/processed_16_64_2/")
    lambdas = LambdaFeaturesCollector(16, 64)
    loader = get_loader(f"../processed_data/test/",
                        16,
                        64,
                        2,
                        metaCollector,
                        lambdas,
                        100,
                        5,
                        train_meta=False)
    generator = Generator(16, 64, 2, metaCollector.getLength(), 100)
    methods = [
        'models_base', 'models_diag', 'models_corp', 'models_cors',
        'models_tspg', 'models_tsph'
    disp = sum([(xi - avg) * (xi - avg) for xi in loss]) / n
    print(np.mean(loss))
    print(disp)


def start_test(train_dir, test_dir, discriminator_location):
    """Evaluate the classifiers and the discriminator on the given datasets.

    Builds a training loader from ``train_dir`` and a test loader from
    ``test_dir`` (the latter with ``train_meta=False``), extracts their
    meta-features and lambdas, then calls ``test_classifiers`` and
    ``test_discriminators``.

    Relies on the module-level names ``datasize``, ``metas`` and ``lambdas``
    being defined before it is called.

    Args:
        train_dir: Directory handed to ``get_loader`` for the training data.
        test_dir: Directory handed to ``get_loader`` for the test data.
        discriminator_location: Forwarded to ``test_discriminators``.
    """
    batch_size, workers = 1, 5
    loader_args = (16, datasize, 2, metas, lambdas, batch_size, workers)

    train_loader = get_loader(train_dir, *loader_args)
    test_loader = get_loader(test_dir, *loader_args, train_meta=False)

    _train_datasets, train_metas, train_lambdas = get_metas_and_lambdas(train_loader)
    test_datasets, test_metas, test_lambdas = get_metas_and_lambdas(test_loader)

    test_classifiers(train_metas, train_lambdas, test_metas, test_lambdas)
    test_discriminators(test_loader, test_datasets, test_metas, test_lambdas,
                        discriminator_location)


if __name__ == '__main__':
    # Command-line options, each paired with its default value.
    parser = argparse.ArgumentParser()
    for flags, default in (
        (("--train-dir",), "../loader/datasets/dprocessed_16_32_2/"),
        (("--test-dir",), "../loader/datasets/dtest32/"),
        (("-d", "--disc-location"),
         "../modifiedLMGAN/models_fullgraph/discriminator-16_32_2-20.pkl"),
    ):
        parser.add_argument(*flags, default=default)

    # Module-level names read by start_test().
    datasize = 32
    lambdas = LambdaFeaturesCollector(16, datasize)
    metas = MetaFeaturesCollector(16, datasize)

    args = parser.parse_args()
    start_test(args.train_dir, args.test_dir, args.disc_location)