Ejemplo n.º 1
0
    def generate(self, weights_path, number_images=100, batch_size=16):
        """Sample images from a trained generator and write them as JPEG files.

        A generator is built from ``self.config`` (LATENT_DIM, IMAGE_SHAPE,
        NUMBER_RESIDUAL_BLOCKS), its weights are loaded from ``weights_path``
        and ``number_images`` images are written to
        ``generated/<YYMMDD_HHMM>/0.jpg``, ``1.jpg``, ...

        :param weights_path: weights file passed to ``load_weights``.
        :param number_images: total number of images to write.
        :param batch_size: number of latent vectors fed to ``predict`` per
            call (was hard-coded to 16).
        :raises ValueError: if ``batch_size`` is smaller than 1 (the loop
            would otherwise never terminate).
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1, got {}".format(batch_size))

        generator = net_utils.generator(self.config.LATENT_DIM,
                                        self.config.IMAGE_SHAPE,
                                        self.config.NUMBER_RESIDUAL_BLOCKS,
                                        base_name="generator")

        generator.load_weights(weights_path)

        now = datetime.datetime.now()
        datetime_sequence = "{0}{1:02d}{2:02d}_{3:02}{4:02d}".format(
            str(now.year)[-2:], now.month, now.day, now.hour, now.minute)

        output_dir = os.path.join("generated", datetime_sequence)
        os.makedirs(output_dir, exist_ok=True)

        counter = 0

        while counter < number_images:
            # Only draw as many latent vectors as are still needed, so the last
            # batch does not generate images that would be thrown away.
            current_batch = min(batch_size, number_images - counter)
            noise = np.random.normal(
                size=(current_batch, self.config.LATENT_DIM)).astype('float32')
            generated_images = generator.predict(noise)
            for image in generated_images:
                file = "{}.jpg".format(counter)
                utils.output_sample_image(os.path.join(output_dir, file),
                                          image)
                counter += 1
Ejemplo n.º 2
0
    def inference(self, model_path, target_dir):
        """Translate every image in ``target_dir`` with the trained mapping "G".

        Files matching ``os.path.join(target_dir, self.config.DATA_EXT)`` are
        read with ``utils.imread``. Each one is run through a model built from
        ``net_utils.mapping_function`` (base name "G") whose weights are loaded
        from ``model_path`` with ``by_name=True``. The result is written to
        ``translated/<basename of target_dir>/<input file basename>``.

        :param model_path: weights file loaded into the inference model.
        :param target_dir: directory containing the images to translate.
        :return: None
        """
        path_list = glob(os.path.join(target_dir, self.config.DATA_EXT))

        output_dir_name = os.path.join("translated",
                                       os.path.basename(target_dir))
        os.makedirs(output_dir_name, exist_ok=True)

        # The model's Input layer is built for a fixed image shape. The
        # previous version rebuilt the network and reloaded the weights for
        # every file, which was very slow. Build once per distinct shape and
        # reuse it instead.
        models_by_shape = {}

        for image_path in path_list:
            image = utils.imread(image_path)
            shape = image.shape

            if shape not in models_by_shape:
                input_layer = Input(shape=shape)
                G = net_utils.mapping_function(
                    shape,
                    base_name="G",
                    num_res_blocks=self.config.NUMBER_RESIDUAL_BLOCKS)
                A2B = G(input_layer)
                model = Model(inputs=[input_layer], outputs=[A2B])
                model.load_weights(model_path, by_name=True)
                models_by_shape[shape] = model
            inference_model = models_by_shape[shape]

            # Add a leading batch axis of size 1 for predict().
            batch = np.array([image])
            translated_image = inference_model.predict(batch)
            name = os.path.basename(image_path)
            output_path = os.path.join(output_dir_name, name)
            utils.output_sample_image(output_path, translated_image[0])
Ejemplo n.º 3
0
    def train_iterations(self, counter=0):
        """Train the critic/generator pair and write samples, weights and curves.

        Files matching ``DATA_DIR/DATASET/DATA_EXT`` are shuffled with a fixed
        seed (42). The first ``round(10%)`` of them become the validation list
        and the rest the training list. Each iteration runs ``NUM_CRITICS``
        discriminator updates via ``self.D_train`` and then one generator
        update via ``self.G_train``.

        Outputs go under ``RESULT_DIR/<YYMMDD_HHMM>_<COMMENT>/``:

        * when ``counter % 10 == 0``: a row is appended to the training curve;
        * when ``counter % 500 == 0``: the validation curve CSV is rewritten
          and a 4x4 sample grid from a fixed noise batch is saved;
        * when ``counter % 5000 == 0``, and once after the last iteration: the
          generator/discriminator models and the training curve CSV are saved.

        :param counter: starting value of the global iteration counter.
        """
        now = datetime.datetime.now()
        datetime_sequence = "{0}{1:02d}{2:02d}_{3:02}{4:02d}".format(
            str(now.year)[-2:], now.month, now.day, now.hour, now.minute)
        file_list = glob(
            os.path.join(self.config.DATA_DIR, self.config.DATASET,
                         self.config.DATA_EXT))

        # Fixed seed so the train/validation split is reproducible across runs.
        random.seed(42)
        random.shuffle(file_list)

        val_ratio = 0.1
        num_val = round(len(file_list) * val_ratio)
        train_file_list = file_list[num_val:]
        val_file_list = file_list[:num_val]

        dataset = utils.data_generator(train_file_list, self.config.BATCH_SIZE)

        experiment_dir = os.path.join(
            self.config.RESULT_DIR,
            datetime_sequence + "_" + self.config.COMMENT)

        sample_output_dir = os.path.join(experiment_dir, "sample",
                                         self.config.DATASET)
        weights_output_dir = os.path.join(experiment_dir, "weights",
                                          self.config.DATASET)
        weights_output_dir_resume = os.path.join(experiment_dir, "weights",
                                                 "resume")

        os.makedirs(sample_output_dir, exist_ok=True)
        os.makedirs(weights_output_dir, exist_ok=True)
        os.makedirs(weights_output_dir_resume, exist_ok=True)

        self.config.output_config(os.path.join(experiment_dir, "config.txt"))

        curve_csv_path = os.path.join(experiment_dir,
                                      self.config.DATASET + ".csv")
        val_csv_path = os.path.join(experiment_dir,
                                    self.config.DATASET + "_val.csv")

        start_time = time.time()
        met_curve = pd.DataFrame(columns=[
            "counter", "loss_d", "loss_d_real", "loss_d_fake", "loss_g"
        ])

        train_val_curve = pd.DataFrame(
            columns=["counter", "train_loss_d", "val_loss_d"])

        # The same 16 latent vectors are used for every sample grid, so grids
        # from different iterations are directly comparable.
        fixed_noise = np.random.normal(
            size=(16, self.config.LATENT_DIM)).astype('float32')

        def load_images(files):
            """Read and resize ``files`` into one numpy batch."""
            return np.array([
                utils.get_image(file, input_hw=self.config.IMAGE_SHAPE[0])
                for file in files
            ])

        def validation_loss_d():
            """Return mean D(real) - mean D(fake) over the validation files.

            Each chunk's mean is weighted by its size, so a short final chunk
            does not skew the result. Returns NaN when there is no validation
            data, instead of dividing by zero.
            """
            val_size = len(val_file_list)
            if val_size == 0:
                return float("nan")
            step = self.config.BATCH_SIZE
            sum_real = 0.0
            sum_fake = 0.0
            # Chunking by start index covers the remainder chunk correctly.
            # The old ``val_file_list[-val_size % B:]`` parsed as
            # ``(-val_size) % B`` and selected the wrong files.
            for start in range(0, val_size, step):
                val_files = val_file_list[start:start + step]
                n_files = len(val_files)
                val_batch = load_images(val_files)
                sum_real += np.mean(
                    self.discriminator.predict(val_batch)) * n_files
                noise = np.random.normal(size=(n_files,
                                               self.config.LATENT_DIM))
                sum_fake += np.mean(
                    self.discriminator.predict(
                        self.generator.predict(noise))) * n_files
            return (sum_real - sum_fake) / val_size

        def save_sample_grid(file):
            """Render ``fixed_noise`` into a 4x4 grid and save it as ``file``."""
            sample = self.generator.predict(fixed_noise)
            h, w, _ = self.config.IMAGE_SHAPE
            sample_array = np.zeros((4 * h, 4 * w, 3))
            for n in range(16):
                i, j = divmod(n, 4)
                sample_array[i * h:(i + 1) * h,
                             j * w:(j + 1) * w, :] = sample[n, :, :, :]
            utils.output_sample_image(os.path.join(sample_output_dir, file),
                                      sample_array)

        def save_models(step, curve):
            """Save both models tagged with ``step`` and the training curve."""
            self.generator.save(
                os.path.join(weights_output_dir,
                             "generator" + str(step) + ".hdf5"))
            self.discriminator.save(
                os.path.join(weights_output_dir,
                             "discriminator" + str(step) + ".hdf5"))
            curve.to_csv(curve_csv_path, index=False)

        for epoch in range(self.config.EPOCH):
            for iteration in range(self.config.ITER_PER_EPOCH):
                for _ in range(self.config.NUM_CRITICS):
                    batch_files = next(dataset)
                    real_batch = load_images(batch_files)

                    noise = np.random.normal(size=(self.config.BATCH_SIZE,
                                                   self.config.LATENT_DIM))
                    # Per-sample interpolation weights, broadcast over H, W, C.
                    epsilon = np.random.uniform(size=(self.config.BATCH_SIZE,
                                                      1, 1, 1))
                    errD_real, errD_fake = self.D_train(
                        [real_batch, noise, epsilon])
                    errD = errD_real - errD_fake

                noise = np.random.normal(size=(self.config.BATCH_SIZE,
                                               self.config.LATENT_DIM))
                errG, = self.G_train([noise])

                elapsed = time.time() - start_time

                print("epoch {0} {1}/{2} loss_d:{3:.4f} loss_d_real:{4:.4f} "
                      "loss_d_fake:{5:.4f}, loss_g:{6:.4f}, {7:.2f}秒".format(
                          epoch, iteration, self.config.ITER_PER_EPOCH, errD,
                          errD_real, errD_fake, errG, elapsed))

                if counter % 10 == 0:
                    temp_df = pd.DataFrame({
                        "counter": [counter],
                        "loss_d": [errD],
                        "loss_d_real": [errD_real],
                        "loss_d_fake": [errD_fake],
                        "loss_g": [errG]
                    })
                    met_curve = pd.concat([met_curve, temp_df], axis=0)

                if counter % 500 == 0:
                    # Compute the validation loss.
                    val_loss = validation_loss_d()
                    temp_df = pd.DataFrame({
                        "counter": [counter],
                        "train_loss_d": [errD],
                        "val_loss_d": [val_loss]
                    })
                    train_val_curve = pd.concat([train_val_curve, temp_df],
                                                axis=0)
                    train_val_curve.to_csv(val_csv_path, index=False)

                    # Write the sample grid.
                    save_sample_grid("{0}_{1}.jpg".format(epoch, counter))

                if counter % 5000 == 0:
                    save_models(counter, met_curve)

                counter += 1

        # Final sample grid and checkpoint. The CSV name now matches the one
        # used inside the loop (the old code read DATASET_A/DATASET_B here).
        save_sample_grid("{0}_{1}.jpg".format(self.config.EPOCH, counter))
        save_models(counter, met_curve)
Ejemplo n.º 4
0
    def train(self):
        """Train a CycleGAN between DATASET_A and DATASET_B with PyTorch.

        Builds two generators (A->B, B->A) and two discriminators (A, B). They
        are trained for ``self.config.EPOCH`` epochs with a BCE adversarial
        loss and an L1 cycle-consistency loss weighted by 10. Results go under
        ``RESULT_DIR/<YYMMDD_HHMM>/``:

        * ``weights/<A>2<B>`` and ``weights/<B>2<A>``: state dicts saved at the
          start of every epoch divisible by 100, and once after training;
        * ``sample/<A>2<B>`` and ``sample/<B>2<A>``: a real/translated pair
          whenever the global iteration counter is a multiple of 500.

        ``weights/resume`` is created but not written by this method.
        """
        num_channels = self.config.NUM_CHANNELS
        use_cuda = self.config.USE_CUDA
        lr = self.config.LEARNING_RATE

        # Networks. Generator_BN / Discriminator_BN were previously tried here
        # as drop-in replacements for these four constructors.
        netG_A2B = Generator(num_channels)
        netG_B2A = Generator(num_channels)
        netD_A = Discriminator(num_channels)
        netD_B = Discriminator(num_channels)
        networks = (netG_A2B, netG_B2A, netD_A, netD_B)

        # Move everything to the GPU first, then initialise the weights. This
        # keeps the original call order.
        if use_cuda:
            for net in networks:
                net.cuda()
        for net in networks:
            net.apply(weights_init_normal)

        # NOTE(review): BCELoss requires discriminator outputs in [0, 1];
        # presumably Discriminator ends in a sigmoid -- confirm.
        criterion_GAN = torch.nn.BCELoss()
        criterion_cycle = torch.nn.L1Loss()

        optimizer_G = torch.optim.Adam(
            itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
            lr=lr, betas=(0.5, 0.999))
        optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=lr,
                                         betas=(0.5, 0.999))
        optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=lr,
                                         betas=(0.5, 0.999))

        def make_scheduler(optimizer):
            """Wrap ``optimizer`` with the project's LambdaLR(EPOCH, 0, EPOCH // 2) schedule."""
            schedule = LambdaLR(self.config.EPOCH, 0, self.config.EPOCH // 2)
            return torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lr_lambda=schedule.step)

        schedulers = [make_scheduler(optimizer)
                      for optimizer in (optimizer_G, optimizer_D_A,
                                        optimizer_D_B)]

        # Inputs & targets memory allocation
        batch_size = self.config.BATCH_SIZE
        height, width, channels = self.config.INPUT_SHAPE

        # Preallocated input buffers. drop_last=True on the DataLoader below
        # keeps every batch at exactly batch_size, so the copy_ calls match.
        input_A = FloatTensor(batch_size, channels, height, width)
        input_B = FloatTensor(batch_size, channels, height, width)
        target_real = Variable(FloatTensor(batch_size).fill_(1.0), requires_grad=False)
        target_fake = Variable(FloatTensor(batch_size).fill_(0.0), requires_grad=False)

        fake_A_buffer = ReplayBuffer()
        fake_B_buffer = ReplayBuffer()

        transforms_ = [transforms.RandomCrop((height, width)),
                       transforms.RandomHorizontalFlip(),
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]

        dataloader = DataLoader(ImageDataset(self.config.DATA_DIR, self.config.DATASET_A, self.config.DATASET_B,
                                             transforms_=transforms_, unaligned=True),
                                batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)
        # Loss plot
        logger = Logger(self.config.EPOCH, len(dataloader))

        now = datetime.datetime.now()
        datetime_sequence = "{0}{1:02d}{2:02d}_{3:02}{4:02d}".format(str(now.year)[-2:], now.month, now.day,
                                                                    now.hour, now.minute)

        output_name_1 = self.config.DATASET_A + "2" + self.config.DATASET_B
        output_name_2 = self.config.DATASET_B + "2" + self.config.DATASET_A

        experiment_dir = os.path.join(self.config.RESULT_DIR, datetime_sequence)

        sample_output_dir_1 = os.path.join(experiment_dir, "sample", output_name_1)
        sample_output_dir_2 = os.path.join(experiment_dir, "sample", output_name_2)
        weights_output_dir_1 = os.path.join(experiment_dir, "weights", output_name_1)
        weights_output_dir_2 = os.path.join(experiment_dir, "weights", output_name_2)
        weights_output_dir_resume = os.path.join(experiment_dir, "weights", "resume")

        for directory in (sample_output_dir_1, sample_output_dir_2,
                          weights_output_dir_1, weights_output_dir_2,
                          weights_output_dir_resume):
            os.makedirs(directory, exist_ok=True)

        def save_checkpoint(tag):
            """Save all four state dicts, prefixed with ``tag`` zero-padded to 4 digits."""
            prefix = str(tag).zfill(4)
            torch.save(netG_A2B.state_dict(), os.path.join(weights_output_dir_1, prefix + 'netG_A2B.pth'))
            torch.save(netG_B2A.state_dict(), os.path.join(weights_output_dir_2, prefix + 'netG_B2A.pth'))
            torch.save(netD_A.state_dict(), os.path.join(weights_output_dir_1, prefix + 'netD_A.pth'))
            torch.save(netD_B.state_dict(), os.path.join(weights_output_dir_2, prefix + 'netD_B.pth'))

        counter = 0

        for epoch in range(self.config.EPOCH):
            # NOTE: the original code had a disabled (string-literal) call that
            # wrote logger.loss_df to CSV here each epoch.
            if epoch % 100 == 0:
                save_checkpoint(epoch)

            for batch in dataloader:
                # Set model input
                real_A = Variable(input_A.copy_(batch['A']))
                real_B = Variable(input_B.copy_(batch['B']))

                ###### Generators A2B and B2A ######
                optimizer_G.zero_grad()

                # GAN loss: the generators try to make the discriminators
                # output "real" on translated images.
                fake_B = netG_A2B(real_A)
                pred_fake_B = netD_B(fake_B)
                loss_GAN_A2B = criterion_GAN(pred_fake_B, target_real)

                fake_A = netG_B2A(real_B)
                pred_fake_A = netD_A(fake_A)
                loss_GAN_B2A = criterion_GAN(pred_fake_A, target_real)

                # Cycle loss (weight 10)
                recovered_A = netG_B2A(fake_B)
                loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0

                recovered_B = netG_A2B(fake_A)
                loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0

                # Total loss
                loss_G = loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
                loss_G.backward()

                optimizer_G.step()

                ###### Discriminator A ######
                optimizer_D_A.zero_grad()

                # Real loss
                pred_A = netD_A(real_A)
                loss_D_real = criterion_GAN(pred_A, target_real)

                # Fake loss: drawn through the replay buffer and detached so
                # this update does not backpropagate into the generators.
                fake_A_ = fake_A_buffer.push_and_pop(fake_A)
                pred_fake = netD_A(fake_A_.detach())
                loss_D_fake = criterion_GAN(pred_fake, target_fake)

                # Total loss
                loss_D_A = (loss_D_real + loss_D_fake) * 0.5
                loss_D_A.backward()

                optimizer_D_A.step()

                ###### Discriminator B ######
                optimizer_D_B.zero_grad()

                # Real loss
                pred_B = netD_B(real_B)
                loss_D_real = criterion_GAN(pred_B, target_real)

                # Fake loss
                fake_B_ = fake_B_buffer.push_and_pop(fake_B)
                pred_fake = netD_B(fake_B_.detach())
                loss_D_fake = criterion_GAN(pred_fake, target_fake)

                # Total loss
                loss_D_B = (loss_D_real + loss_D_fake) * 0.5
                loss_D_B.backward()

                optimizer_D_B.step()

                # Progress report (http://localhost:8097)
                logger.log({'loss_G': loss_G,
                            'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A),
                            'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB), 'loss_D': (loss_D_A + loss_D_B)},
                           images={'real_A': real_A, 'real_B': real_B, 'fake_A': fake_A, 'fake_B': fake_B})

                if counter % 500 == 0:
                    real_A_sample = real_A.cpu().detach().numpy()[0]
                    pred_A_sample = fake_A.cpu().detach().numpy()[0]
                    real_B_sample = real_B.cpu().detach().numpy()[0]
                    pred_B_sample = fake_B.cpu().detach().numpy()[0]
                    # Samples are channel-first (C, H, W), so axis=2 places
                    # the real and translated images side by side.
                    combine_sample_1 = np.concatenate([real_A_sample, pred_B_sample], axis=2)
                    combine_sample_2 = np.concatenate([real_B_sample, pred_A_sample], axis=2)

                    file_name = "{0}_{1}.jpg".format(epoch, counter)
                    output_sample_image(os.path.join(sample_output_dir_1, file_name), combine_sample_1)
                    output_sample_image(os.path.join(sample_output_dir_2, file_name), combine_sample_2)

                counter += 1

            # Update learning rates once per epoch.
            for scheduler in schedulers:
                scheduler.step()

        save_checkpoint(self.config.EPOCH)
Ejemplo n.º 5
0
    def training_iteration(self, Dx, Dy, generator, G, F, counter=0):
        """Train a Keras CycleGAN and periodically write samples, weights and losses.

        Each iteration draws one minibatch from domain A and one from domain B.
        It updates the combined ``generator`` model once, then ``Dx`` (trained
        on real A vs. F(B)) and ``Dy`` (trained on real B vs. G(A)) once each.
        Outputs go under ``RESULT_DIR/<YYMMDD_HHMM>/``:

        * when ``counter % 10 == 0``: a row is appended to the loss table;
        * when ``counter % 200 == 0``: A->B and B->A sample images;
        * when ``counter % 1000 == 0``, and once after training: weights of G,
          F, ``generator``, Dx and Dy, plus the loss table as CSV.

        :param Dx: discriminator for domain-A images.
        :param Dy: discriminator for domain-B images.
        :param generator: combined model with inputs "input_A"/"input_B" and
            outputs "Dy", "Dx", "G", "F".
        :param G: mapping A -> B.
        :param F: mapping B -> A.
        :param counter: starting value of the global iteration counter.
        """
        now = datetime.datetime.now()
        datetime_sequence = "{0}{1:02d}{2:02d}_{3:02}{4:02d}".format(
            str(now.year)[-2:], now.month, now.day, now.hour, now.minute)

        datasetA = utils.data_generator(
            os.path.join(self.config.DATA_DIR, self.config.DATASET_A,
                         self.config.DATA_EXT), self.config.BATCH_SIZE)
        datasetB = utils.data_generator(
            os.path.join(self.config.DATA_DIR, self.config.DATASET_B,
                         self.config.DATA_EXT), self.config.BATCH_SIZE)

        output_name_1 = self.config.DATASET_A + "2" + self.config.DATASET_B
        output_name_2 = self.config.DATASET_B + "2" + self.config.DATASET_A

        experiment_dir = os.path.join(self.config.RESULT_DIR,
                                      datetime_sequence)

        sample_output_dir_1 = os.path.join(experiment_dir, "sample",
                                           output_name_1)
        sample_output_dir_2 = os.path.join(experiment_dir, "sample",
                                           output_name_2)
        weights_output_dir_1 = os.path.join(experiment_dir, "weights",
                                            output_name_1)
        weights_output_dir_2 = os.path.join(experiment_dir, "weights",
                                            output_name_2)
        weights_output_dir_resume = os.path.join(experiment_dir, "weights",
                                                 "resume")

        os.makedirs(sample_output_dir_1, exist_ok=True)
        os.makedirs(sample_output_dir_2, exist_ok=True)
        os.makedirs(weights_output_dir_1, exist_ok=True)
        os.makedirs(weights_output_dir_2, exist_ok=True)
        os.makedirs(weights_output_dir_resume, exist_ok=True)

        csv_path = os.path.join(
            experiment_dir,
            self.config.DATASET_A + "_" + self.config.DATASET_B + ".csv")

        start_time = time.time()
        met_curve = pd.DataFrame(columns=[
            "counter", "loss_Dx", "loss_Dy", "loss_G", "adversarial_loss",
            "cycle_loss"
        ])

        def load_batch(dataset):
            """Pull the next list of files from ``dataset`` and load them as one array."""
            batch_files = next(dataset)
            return np.array([
                utils.get_image(file, input_hw=self.config.INPUT_SHAPE[0])
                for file in batch_files
            ])

        def save_checkpoint(step, curve):
            """Save all model weights tagged with ``step`` and write ``curve`` as CSV."""
            net_utils.save_weights(G, weights_output_dir_1, step)
            net_utils.save_weights(F, weights_output_dir_2, step)

            net_utils.save_weights(generator,
                                   weights_output_dir_resume,
                                   step,
                                   base_name="generator")
            # Each discriminator is saved under its own name. These were
            # previously swapped (Dx saved as "Dy" and vice versa).
            # NOTE(review): check any resume loader that relied on the swap.
            net_utils.save_weights(Dx,
                                   weights_output_dir_resume,
                                   step,
                                   base_name="Dx")
            net_utils.save_weights(Dy,
                                   weights_output_dir_resume,
                                   step,
                                   base_name="Dy")

            curve.to_csv(csv_path, index=False)

        for epoch in range(self.config.EPOCH):
            for iteration in range(self.config.ITER_PER_EPOCH):
                # generate minibatch
                batch_A = load_batch(datasetA)
                batch_B = load_batch(datasetB)

                # Update the generators. The targets for the "Dx"/"Dy" outputs
                # are ones, i.e. the generators want translated images scored
                # as real.
                loss_g = generator.train_on_batch(
                    x={
                        "input_A": batch_A,
                        "input_B": batch_B
                    },
                    y={
                        "Dy": np.ones(self.config.BATCH_SIZE),
                        "Dx": np.ones(self.config.BATCH_SIZE),
                        "G": batch_B,
                        "F": batch_A
                    })

                # Update Dx: real A images are labelled 1 and translated
                # images F(B) are labelled 0, matching the generator's target
                # of ones. (The previous labels were inverted: real -> 0,
                # fake -> 1.)
                # NOTE(review): assumes the "Dx"/"Dy" outputs of `generator`
                # are the discriminators' scores on translated images --
                # confirm against the model builder.
                fake_A = F.predict(batch_B)
                X = np.append(batch_A, fake_A, axis=0)
                y = np.array([1] * len(batch_A) + [0] * len(fake_A))
                loss_d_x, _ = Dx.train_on_batch(X, y)

                # Update Dy: real B -> 1, translated G(A) -> 0.
                fake_B = G.predict(batch_A)
                X = np.append(batch_B, fake_B, axis=0)
                y = np.array([1] * len(batch_B) + [0] * len(fake_B))
                loss_d_y, _ = Dy.train_on_batch(X, y)

                elapsed = time.time() - start_time

                print("epoch {0} {1}/{2} loss_d_x:{3:.4f} loss_d_y:{4:.4f} "
                      "loss_g:{5:.4f} {6:.2f}秒".format(
                          epoch, iteration, self.config.ITER_PER_EPOCH,
                          loss_d_x, loss_d_y, loss_g[0], elapsed))

                if counter % 10 == 0:
                    # loss_g[0] is the total; indices 1-2 and 3-4 are summed
                    # into adversarial and cycle terms (order as compiled).
                    temp_df = pd.DataFrame({
                        "counter": [counter],
                        "loss_Dx": [loss_d_x],
                        "loss_Dy": [loss_d_y],
                        "loss_G": [loss_g[0]],
                        "adversarial_loss": [loss_g[1] + loss_g[2]],
                        "cycle_loss": [loss_g[3] + loss_g[4]]
                    })
                    met_curve = pd.concat([met_curve, temp_df], axis=0)

                if counter % 200 == 0:
                    sample_1 = G.predict(batch_A)
                    sample_2 = F.predict(batch_B)
                    # Put the input and its translation side by side.
                    combine_sample_1 = np.concatenate(
                        [batch_A[0], sample_1[0]], axis=1)
                    combine_sample_2 = np.concatenate(
                        [batch_B[0], sample_2[0]], axis=1)

                    file_name = "{0}_{1}.jpg".format(epoch, counter)
                    utils.output_sample_image(
                        os.path.join(sample_output_dir_1, file_name),
                        combine_sample_1)
                    utils.output_sample_image(
                        os.path.join(sample_output_dir_2, file_name),
                        combine_sample_2)

                if counter % 1000 == 0:
                    save_checkpoint(counter, met_curve)

                counter += 1

        save_checkpoint(counter, met_curve)