Example #1

# Assumed imports for this excerpt; Encoder, Decoder, weights_init and
# transform_config come from the repository's own modules:
import os
import numpy as np
import torch
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from torchvision import datasets

def main(FLAGS):
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # fall back to CPU when CUDA is unavailable
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    decoder.to(device)
    encoder.to(device)

    tsne = TSNE(2)

    mnist = DataLoader(
        datasets.MNIST(root='mnist',
                       download=True,
                       train=False,
                       transform=transform_config))
    s_dict = {}
    with torch.no_grad():
        for i, (image, label) in enumerate(mnist):
            label = int(label)
            print(i, label)
            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(
                image.to(device))
            s_dict.setdefault(label, []).append(class_latent_space_1)

    s_all = []
    for label in range(10):
        s_all.extend(s_dict[label])

    s_all = torch.cat(s_all)
    s_all = s_all.view(s_all.shape[0], -1).cpu()

    s_2d = tsne.fit_transform(s_all)

    np.savez('s_2d.npz', s_2d=s_2d)
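The script above only computes and saves the 2-D embedding. As a usage note, here is a minimal sketch (not part of the original script) for inspecting the result; it only assumes that s_2d.npz was produced by the code above:

import numpy as np
import matplotlib.pyplot as plt

s_2d = np.load('s_2d.npz')['s_2d']
# points are ordered by digit class (all 0s first, then 1s, ...), per s_all above
plt.scatter(s_2d[:, 0], s_2d[:, 1], s=2)
plt.title('t-SNE of the MNIST class latent space')
plt.show()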
parser.add_argument("--device", default="cuda", help="device: cuda | cpu")
parser.add_argument("--G_path", default="ckpts/G_epoch19.pt", help="path to trained state dict of generator")
parser.add_argument("--D_path", default="ckpts/D_epoch19.pt", help="path to trained state dict of discriminator")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

generator = Generator(dim=64, zdim=opt.latent_dim, nc=opt.channels)
discriminator = Discriminator(dim=64, zdim=opt.latent_dim, nc=opt.channels, out_feat=True)
encoder = Encoder(dim=64, zdim=opt.latent_dim, nc=opt.channels)

generator.load_state_dict(torch.load(opt.G_path, map_location=opt.device))
discriminator.load_state_dict(torch.load(opt.D_path, map_location=opt.device))
generator.to(opt.device)
encoder.to(opt.device)
discriminator.to(opt.device)

encoder.train()
discriminator.train()

dataloader = load_data(opt)

generator.eval()

Tensor = torch.cuda.FloatTensor if opt.device == 'cuda' else torch.FloatTensor

optimizer_E = torch.optim.Adam(encoder.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

max_auc = 0
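A side note on the Tensor alias above: the torch.cuda.FloatTensor constructors are legacy API. On current PyTorch the same intent is usually expressed with an explicit device; a sketch, assuming the same opt namespace:

device = torch.device(opt.device if torch.cuda.is_available() else 'cpu')
# allocate directly on the target device instead of via type constructors
real_labels = torch.ones(opt.batch_size, 1, device=device)  # illustrative tensor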
Example #3
def training_procedure(FLAGS):
    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
    """
    variable definition
    """
    X = torch.FloatTensor(FLAGS.batch_size, 784)
    '''
    run on GPU if GPU is available
    '''
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder.to(device=device)
    decoder.to(device=device)
    X = X.to(device=device)
    """
    optimizer definition
    """
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))
    """
    
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    dirs = os.listdir(os.path.join(os.getcwd(), 'data'))
    print('Loading double multivariate normal time series data...')
    for dsname in dirs:
        params = dsname.split('_')
        if params[2] == 'theta=-1':
            print('Running dataset ', dsname)
            ds = DoubleMulNormal(dsname)
            # ds = experiment3(1000, 50, 3)
            loader = cycle(
                DataLoader(ds,
                           batch_size=FLAGS.batch_size,
                           shuffle=True,
                           drop_last=True))

            # initialize summary writer
            writer = SummaryWriter()

            for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
                print()
                print(
                    'Epoch #' + str(epoch) +
                    '........................................................')

                # the total loss at each epoch after running iterations of batches
                total_loss = 0

                for iteration in range(int(len(ds) / FLAGS.batch_size)):
                    # load a mini-batch
                    image_batch, labels_batch = next(loader)

                    # set zero_grad for the optimizer
                    auto_encoder_optimizer.zero_grad()

                    X.copy_(image_batch)

                    style_mu, style_logvar, class_mu, class_logvar = encoder(
                        Variable(X))
                    grouped_mu, grouped_logvar = accumulate_group_evidence(
                        class_mu.data, class_logvar.data, labels_batch,
                        FLAGS.cuda)

                    # kl-divergence error for style latent space
                    style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                        -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                         style_logvar.exp()))
                    style_kl_divergence_loss /= (FLAGS.batch_size *
                                                 FLAGS.num_channels *
                                                 FLAGS.image_size *
                                                 FLAGS.image_size)
                    style_kl_divergence_loss.backward(retain_graph=True)

                    # kl-divergence error for class latent space
                    class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                        -0.5 *
                        torch.sum(1 + grouped_logvar - grouped_mu.pow(2) -
                                  grouped_logvar.exp()))
                    class_kl_divergence_loss /= (FLAGS.batch_size *
                                                 FLAGS.num_channels *
                                                 FLAGS.image_size *
                                                 FLAGS.image_size)
                    class_kl_divergence_loss.backward(retain_graph=True)

                    # reconstruct samples
                    """
                    sampling from group mu and logvar for each image in mini-batch differently makes
                    the decoder consider class latent embeddings as random noise and ignore them 
                    """
                    style_latent_embeddings = reparameterize(
                        training=True, mu=style_mu, logvar=style_logvar)
                    class_latent_embeddings = group_wise_reparameterize(
                        training=True,
                        mu=grouped_mu,
                        logvar=grouped_logvar,
                        labels_batch=labels_batch,
                        cuda=FLAGS.cuda)

                    reconstructed_images = decoder(style_latent_embeddings,
                                                   class_latent_embeddings)

                    reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                        reconstructed_images, Variable(X))
                    reconstruction_error.backward()

                    total_loss += style_kl_divergence_loss + class_kl_divergence_loss + reconstruction_error

                    auto_encoder_optimizer.step()

                    if (iteration + 1) % 50 == 0:
                        print('\tIteration #' + str(iteration))
                        print('Reconstruction loss: ' +
                              str(reconstruction_error.item()))
                        print('Style KL loss: ' +
                              str(style_kl_divergence_loss.item()))
                        print('Class KL loss: ' +
                              str(class_kl_divergence_loss.item()))

                    # write to log
                    with open(FLAGS.log_file, 'a') as log:
                        log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                            epoch, iteration,
                            reconstruction_error.item(),
                            style_kl_divergence_loss.item(),
                            class_kl_divergence_loss.item()))

                    # write to tensorboard
                    writer.add_scalar(
                        'Reconstruction loss', reconstruction_error.item(),
                        epoch * (int(len(ds) / FLAGS.batch_size) + 1) +
                        iteration)
                    writer.add_scalar(
                        'Style KL-Divergence loss',
                        style_kl_divergence_loss.item(),
                        epoch * (int(len(ds) / FLAGS.batch_size) + 1) +
                        iteration)
                    writer.add_scalar(
                        'Class KL-Divergence loss',
                        class_kl_divergence_loss.item(),
                        epoch * (int(len(ds) / FLAGS.batch_size) + 1) +
                        iteration)

                    if epoch == 0 and (iteration + 1) % 50 == 0:
                        torch.save(
                            encoder.state_dict(),
                            os.path.join('checkpoints', 'encoder_' + dsname))
                        torch.save(
                            decoder.state_dict(),
                            os.path.join('checkpoints', 'decoder_' + dsname))

                # save checkpoints after every 10 epochs
                if (epoch + 1) % 10 == 0 or (epoch + 1) == FLAGS.end_epoch:
                    torch.save(
                        encoder.state_dict(),
                        os.path.join('checkpoints', 'encoder_' + dsname))
                    torch.save(
                        decoder.state_dict(),
                        os.path.join('checkpoints', 'decoder_' + dsname))

                print('Total loss at current epoch: ', total_loss.item())
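reparameterize and group_wise_reparameterize are imported from the repository's utilities and are not shown in this excerpt. For orientation, here is a sketch of the standard reparameterization trick they presumably implement; the group-wise variant shares one noise sample across all members of a group, which is exactly what the docstring above warns about:

def reparameterize(training, mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); return the mean at eval time
    if not training:
        return mu
    std = logvar.mul(0.5).exp()
    return mu + torch.randn_like(std) * std

def group_wise_reparameterize(training, mu, logvar, labels_batch, cuda):
    # sketch: draw ONE eps per group and reuse it for every member, so the
    # class embedding is identical within a group (cuda kept for signature parity)
    if not training:
        return mu
    std = logvar.mul(0.5).exp()
    eps = torch.randn_like(std)
    for label in labels_batch.unique():
        idx = (labels_batch == label).nonzero(as_tuple=True)[0]
        eps[idx] = eps[idx[0]].clone()  # share the first member's noise
    return mu + eps * std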
Example #4
def test(opt):
    #### mkdir
    des_pth = os.path.join('results', opt.name)
    if not os.path.exists(des_pth):
        os.makedirs(des_pth)  # makedirs also creates 'results' if it is missing
    src_pth = os.path.join(opt.checkpoints, opt.name)

    models_name = os.listdir(src_pth)
    models_name.remove('images')
    models_name.remove('records.txt')
    models_name.sort(key=lambda x: int(x[6:9]))
    target = int(models_name[-1][6:9])

    #### device
    device = torch.device('cuda:{}'.format(opt.gpu_id) if opt.gpu_id >= 0 else 'cpu')

    #### data
    data_loader = UnAlignedDataLoader()
    data_loader.initialize(opt)
    data_set = data_loader.load_data()

    #### networks
    ## initialize
    E_a2b = Encoder(input_nc=opt.input_nc, ngf=opt.ngf, norm_type=opt.norm_type, use_dropout=not opt.no_dropout, n_blocks=9)
    G_b = Decoder(output_nc=opt.output_nc, ngf=opt.ngf, norm_type=opt.norm_type)
    E_b2a = Encoder(input_nc=opt.input_nc, ngf=opt.ngf, norm_type=opt.norm_type, use_dropout=not opt.no_dropout, n_blocks=9)
    G_a = Decoder(output_nc=opt.output_nc, ngf=opt.ngf, norm_type=opt.norm_type)

    ## load in models
    E_a2b.load_state_dict(torch.load(os.path.join(src_pth, 'epoch_%3d-E_a2b.pth' % target)))
    G_b.load_state_dict(torch.load(os.path.join(src_pth, 'epoch_%3d-G_b.pth' % target)))
    E_b2a.load_state_dict(torch.load(os.path.join(src_pth, 'epoch_%3d-E_b2a.pth' % target)))
    G_a.load_state_dict(torch.load(os.path.join(src_pth, 'epoch_%3d-G_a.pth' % target)))

    E_a2b = E_a2b.to(device)
    G_b = G_b.to(device)
    E_b2a = E_b2a.to(device)
    G_a = G_a.to(device)

    # switch to eval mode so dropout/normalization layers behave deterministically
    E_a2b.eval()
    G_b.eval()
    E_b2a.eval()
    G_a.eval()

    for i, data in enumerate(data_set):
        real_A = data['A'].to(device)
        real_B = data['B'].to(device)

        with torch.no_grad():  # inference only; avoid building autograd graphs
            fake_B = G_b(E_a2b(real_A))
            fake_A = G_a(E_b2a(real_B))

        ## visualize
        if opt.gpu_id >= 0:
            fake_B = fake_B.cpu().data
            fake_A = fake_A.cpu().data

            real_A = real_A.cpu()
            real_B = real_B.cpu()

        for j in range(fake_B.size(0)):  # the last batch may be smaller than opt.batch_size
            fake_b = tensor2image_RGB(fake_B[j, ...])
            fake_a = tensor2image_RGB(fake_A[j, ...])

            real_a = tensor2image_RGB(real_A[j, ...])
            real_b = tensor2image_RGB(real_B[j, ...])

            plt.subplot(221), plt.title("real_A"), plt.imshow(real_a)
            plt.subplot(222), plt.title("fake_B"), plt.imshow(fake_b)
            plt.subplot(223), plt.title("real_B"), plt.imshow(real_b)
            plt.subplot(224), plt.title("fake_A"), plt.imshow(fake_a)

            plt.savefig(os.path.join(des_pth, '%06d-%02d.jpg'%(i, j)))
        #break #-> debug

    print("≧◔◡◔≦ Congratulation! Successfully finishing the testing!")
Example #5
def train(opt):
    #### device
    device = torch.device('cuda:{}'.format(opt.gpu_id)
                          if opt.gpu_id >= 0 else 'cpu')

    #### dataset
    data_loader = UnAlignedDataLoader()
    data_loader.initialize(opt)
    data_set = data_loader.load_data()
    print("The number of training images = %d." % len(data_set))

    #### initialize models
    ## declaration
    E_a2Zb = Encoder(input_nc=opt.input_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type,
                     use_dropout=not opt.no_dropout,
                     n_blocks=9)
    G_Zb2b = Decoder(output_nc=opt.output_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type)
    T_Zb2Za = LatentTranslator(n_channels=256,
                               norm_type=opt.norm_type,
                               use_dropout=not opt.no_dropout)
    D_b = Discriminator(input_nc=opt.input_nc,
                        ndf=opt.ndf,
                        n_layers=opt.n_layers,
                        norm_type=opt.norm_type)

    E_b2Za = Encoder(input_nc=opt.input_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type,
                     use_dropout=not opt.no_dropout,
                     n_blocks=9)
    G_Za2a = Decoder(output_nc=opt.output_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type)
    T_Za2Zb = LatentTranslator(n_channels=256,
                               norm_type=opt.norm_type,
                               use_dropout=not opt.no_dropout)
    D_a = Discriminator(input_nc=opt.input_nc,
                        ndf=opt.ndf,
                        n_layers=opt.n_layers,
                        norm_type=opt.norm_type)

    ## initialization
    E_a2Zb = init_net(E_a2Zb, init_type=opt.init_type).to(device)
    G_Zb2b = init_net(G_Zb2b, init_type=opt.init_type).to(device)
    T_Zb2Za = init_net(T_Zb2Za, init_type=opt.init_type).to(device)
    D_b = init_net(D_b, init_type=opt.init_type).to(device)

    E_b2Za = init_net(E_b2Za, init_type=opt.init_type).to(device)
    G_Za2a = init_net(G_Za2a, init_type=opt.init_type).to(device)
    T_Za2Zb = init_net(T_Za2Zb, init_type=opt.init_type).to(device)
    D_a = init_net(D_a, init_type=opt.init_type).to(device)
    print(
        "+------------------------------------------------------+\nFinished initializing networks."
    )

    #### optimizer and criterion
    ## criterion
    criterionGAN = GANLoss(opt.gan_mode).to(device)
    criterionZId = nn.L1Loss()
    criterionIdt = nn.L1Loss()
    criterionCTC = nn.L1Loss()
    criterionZCyc = nn.L1Loss()

    ## optimizer
    optimizer_G = torch.optim.Adam(itertools.chain(E_a2Zb.parameters(),
                                                   G_Zb2b.parameters(),
                                                   T_Zb2Za.parameters(),
                                                   E_b2Za.parameters(),
                                                   G_Za2a.parameters(),
                                                   T_Za2Zb.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, opt.beta2))
    optimizer_D = torch.optim.Adam(itertools.chain(D_a.parameters(),
                                                   D_b.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, opt.beta2))

    ## scheduler
    scheduler = [
        get_scheduler(optimizer_G, opt),
        get_scheduler(optimizer_D, opt)
    ]

    print(
        "+------------------------------------------------------+\nFinished initializing the optimizers and criteria."
    )

    #### global variables
    checkpoints_pth = os.path.join(opt.checkpoints, opt.name)
    os.makedirs(checkpoints_pth, exist_ok=True)
    os.makedirs(os.path.join(checkpoints_pth, 'images'), exist_ok=True)
    record_fh = open(os.path.join(checkpoints_pth, 'records.txt'),
                     'w',
                     encoding='utf-8')
    loss_names = [
        'GAN_A', 'Adv_A', 'Idt_A', 'CTC_A', 'ZId_A', 'ZCyc_A', 'GAN_B',
        'Adv_B', 'Idt_B', 'CTC_B', 'ZId_B', 'ZCyc_B'
    ]

    # image buffers that store previously generated images for the discriminators
    fake_A_pool = ImagePool(opt.pool_size)
    fake_B_pool = ImagePool(opt.pool_size)

    print(
        "+------------------------------------------------------+\nFinished preparing the other work."
    )
    print(
        "+------------------------------------------------------+\nNow training begins ..."
    )
    #### training
    cur_iter = 0
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()  # timer for entire epoch

        for i, data in enumerate(data_set):
            ## setup inputs
            real_A = data['A'].to(device)
            real_B = data['B'].to(device)

            ## forward
            # image cycle / GAN
            latent_B = E_a2Zb(real_A)  #-> a -> Zb     : E_a2b(a)
            fake_B = G_Zb2b(latent_B)  #-> Zb -> b'    : G_b(E_a2b(a))
            latent_A = E_b2Za(real_B)  #-> b -> Za     : E_b2a(b)
            fake_A = G_Za2a(latent_A)  #-> Za -> a'    : G_a(E_b2a(b))

            # Idt
            '''
            rec_A = G_Za2a(E_b2Za(fake_B))          #-> b' -> Za' -> rec_a  : G_a(E_b2a(fake_b))
            rec_B = G_Zb2b(E_a2Zb(fake_A))          #-> a' -> Zb' -> rec_b  : G_b(E_a2b(fake_a))
            '''
            idt_latent_A = E_b2Za(real_A)  #-> a -> Za        : E_b2a(a)
            idt_A = G_Za2a(idt_latent_A)  #-> Za -> idt_a    : G_a(E_b2a(a))
            idt_latent_B = E_a2Zb(real_B)  #-> b -> Zb        : E_a2b(b)
            idt_B = G_Zb2b(idt_latent_B)  #-> Zb -> idt_b    : G_b(E_a2b(b))

            # ZIdt
            T_latent_A = T_Zb2Za(latent_B)  #-> Zb -> Za''  : T_b2a(E_a2b(a))
            T_rec_A = G_Za2a(
                T_latent_A)  #-> Za'' -> a'' : G_a(T_b2a(E_a2b(a)))
            T_latent_B = T_Za2Zb(latent_A)  #-> Za -> Zb''  : T_a2b(E_b2a(b))
            T_rec_B = G_Zb2b(
                T_latent_B)  #-> Zb'' -> b'' : G_b(T_a2b(E_b2a(b)))

            # CTC
            T_idt_latent_B = T_Za2Zb(idt_latent_A)  #-> a -> T_a2b(E_b2a(a))
            T_idt_latent_A = T_Zb2Za(idt_latent_B)  #-> b -> T_b2a(E_a2b(b))

            # ZCyc
            TT_latent_B = T_Za2Zb(T_latent_A)  #-> T_a2b(T_b2a(E_a2b(a)))
            TT_latent_A = T_Zb2Za(T_latent_B)  #-> T_b2a(T_a2b(E_b2a(b)))

            ### optimize parameters
            ## Generator updating
            set_requires_grad(
                [D_b, D_a],
                False)  #-> set Discriminator to require no gradient
            optimizer_G.zero_grad()
            # GAN loss
            loss_G_A = criterionGAN(D_b(fake_B), True)
            loss_G_B = criterionGAN(D_a(fake_A), True)
            loss_GAN = loss_G_A + loss_G_B
            # Idt loss
            loss_idt_A = criterionIdt(idt_A, real_A)
            loss_idt_B = criterionIdt(idt_B, real_B)
            loss_Idt = loss_idt_A + loss_idt_B
            # Latent cross-identity loss
            loss_Zid_A = criterionZId(T_rec_A, real_A)
            loss_Zid_B = criterionZId(T_rec_B, real_B)
            loss_Zid = loss_Zid_A + loss_Zid_B
            # Latent cross-translation consistency
            loss_CTC_A = criterionCTC(T_idt_latent_A, latent_A)
            loss_CTC_B = criterionCTC(T_idt_latent_B, latent_B)
            loss_CTC = loss_CTC_B + loss_CTC_A
            # Latent cycle consistency
            loss_ZCyc_A = criterionZCyc(TT_latent_A, latent_A)
            loss_ZCyc_B = criterionZCyc(TT_latent_B, latent_B)
            loss_ZCyc = loss_ZCyc_B + loss_ZCyc_A

            loss_G = (opt.lambda_gan * loss_GAN + opt.lambda_idt * loss_Idt +
                      opt.lambda_zid * loss_Zid + opt.lambda_ctc * loss_CTC +
                      opt.lambda_zcyc * loss_ZCyc)

            # backward and gradient updating
            loss_G.backward()
            optimizer_G.step()

            ## Discriminator updating
            set_requires_grad([D_b, D_a],
                              True)  # -> set Discriminator to require gradient
            optimizer_D.zero_grad()

            # backward D_b
            fake_B_ = fake_B_pool.query(fake_B)
            #-> real_B, fake_B
            pred_real_B = D_b(real_B)
            loss_D_real_B = criterionGAN(pred_real_B, True)

            pred_fake_B = D_b(fake_B_)
            loss_D_fake_B = criterionGAN(pred_fake_B, False)

            loss_D_B = (loss_D_real_B + loss_D_fake_B) * 0.5
            loss_D_B.backward()

            # backward D_a
            fake_A_ = fake_A_pool.query(fake_A)
            #-> real_A, fake_A
            pred_real_A = D_a(real_A)
            loss_D_real_A = criterionGAN(pred_real_A, True)

            pred_fake_A = D_a(fake_A_)
            loss_D_fake_A = criterionGAN(pred_fake_A, False)

            loss_D_A = (loss_D_real_A + loss_D_fake_A) * 0.5
            loss_D_A.backward()

            # update the gradients
            optimizer_D.step()

            ### validate here, both qualitively and quantitatively
            ## record the losses
            if cur_iter % opt.log_freq == 0:
                # loss_names = ['GAN_A', 'Adv_A', 'Idt_A', 'CTC_A', 'ZId_A', 'ZCyc_A', 'GAN_B', 'Adv_B', 'Idt_B', 'CTC_B', 'ZId_B', 'ZCyc_B']
                losses = [
                    loss_G_A.item(),
                    loss_D_A.item(),
                    loss_idt_A.item(),
                    loss_CTC_A.item(),
                    loss_Zid_A.item(),
                    loss_ZCyc_A.item(),
                    loss_G_B.item(),
                    loss_D_B.item(),
                    loss_idt_B.item(),
                    loss_CTC_B.item(),
                    loss_Zid_B.item(),
                    loss_ZCyc_B.item()
                ]
                # record
                record_fh.write(' '.join(str(loss) for loss in losses) + '\n')
                # print out
                print('Epoch: %3d/%3d Iter: %9d --------------------------+' %
                      (epoch, opt.niter + opt.niter_decay, i))
                field_names = loss_names[:len(loss_names) // 2]
                table = PrettyTable(field_names=field_names)
                for l_n in field_names:
                    table.align[l_n] = 'c'
                table.add_row(losses[:len(field_names)])
                print(table.get_string(reversesort=True))

                field_names = loss_names[len(loss_names) // 2:]
                table = PrettyTable(field_names=field_names)
                for l_n in field_names:
                    table.align[l_n] = 'c'
                table.add_row(losses[-len(field_names):])
                print(table.get_string(reversesort=True))

            ## visualize
            if cur_iter % opt.vis_freq == 0:
                if opt.gpu_id >= 0:
                    real_A = real_A.cpu().data
                    real_B = real_B.cpu().data
                    fake_A = fake_A.cpu().data
                    fake_B = fake_B.cpu().data
                    idt_A = idt_A.cpu().data
                    idt_B = idt_B.cpu().data
                    T_rec_A = T_rec_A.cpu().data
                    T_rec_B = T_rec_B.cpu().data

                plt.subplot(241), plt.title('real_A'), plt.imshow(
                    tensor2image_RGB(real_A[0, ...]))
                plt.subplot(242), plt.title('fake_B'), plt.imshow(
                    tensor2image_RGB(fake_B[0, ...]))
                plt.subplot(243), plt.title('idt_A'), plt.imshow(
                    tensor2image_RGB(idt_A[0, ...]))
                plt.subplot(244), plt.title('L_idt_A'), plt.imshow(
                    tensor2image_RGB(T_rec_A[0, ...]))

                plt.subplot(245), plt.title('real_B'), plt.imshow(
                    tensor2image_RGB(real_B[0, ...]))
                plt.subplot(246), plt.title('fake_A'), plt.imshow(
                    tensor2image_RGB(fake_A[0, ...]))
                plt.subplot(247), plt.title('idt_B'), plt.imshow(
                    tensor2image_RGB(idt_B[0, ...]))
                plt.subplot(248), plt.title('L_idt_B'), plt.imshow(
                    tensor2image_RGB(T_rec_B[0, ...]))

                plt.savefig(
                    os.path.join(checkpoints_pth, 'images',
                                 '%03d_%09d.jpg' % (epoch, i)))

            cur_iter += 1
            #break #-> debug

        ## till now, we finish one epoch, try to update the learning rate
        update_learning_rate(schedulers=scheduler,
                             opt=opt,
                             optimizer=optimizer_D)
        ## save the model
        if epoch % opt.ckp_freq == 0:
            #-> save models
            # torch.save(model.state_dict(), PATH)
            #-> load in models
            # model.load_state_dict(torch.load(PATH))
            # model.eval()
            if opt.gpu_id >= 0:
                E_a2Zb = E_a2Zb.cpu()
                G_Zb2b = G_Zb2b.cpu()
                T_Zb2Za = T_Zb2Za.cpu()
                D_b = D_b.cpu()

                E_b2Za = E_b2Za.cpu()
                G_Za2a = G_Za2a.cpu()
                T_Za2Zb = T_Za2Zb.cpu()
                D_a = D_a.cpu()
                '''
                torch.save( E_a2Zb.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-E_a2b.pth' % epoch))
                torch.save( G_Zb2b.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-G_b.pth' % epoch))
                torch.save(T_Zb2Za.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-T_b2a.pth' % epoch))
                torch.save(    D_b.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-D_b.pth' % epoch))

                torch.save( E_b2Za.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-E_b2a.pth' % epoch))
                torch.save( G_Za2a.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-G_a.pth' % epoch))
                torch.save(T_Za2Zb.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-T_a2b.pth' % epoch))
                torch.save(    D_a.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-D_a.pth' % epoch))
                '''
            torch.save(
                E_a2Zb.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-E_a2b.pth' % epoch))
            torch.save(
                G_Zb2b.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-G_b.pth' % epoch))
            torch.save(
                T_Zb2Za.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-T_b2a.pth' % epoch))
            torch.save(
                D_b.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-D_b.pth' % epoch))

            torch.save(
                E_b2Za.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-E_b2a.pth' % epoch))
            torch.save(
                G_Za2a.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-G_a.pth' % epoch))
            torch.save(
                T_Za2Zb.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-T_a2b.pth' % epoch))
            torch.save(
                D_a.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-D_a.pth' % epoch))
            if opt.gpu_id >= 0:
                E_a2Zb = E_a2Zb.to(device)
                G_Zb2b = G_Zb2b.to(device)
                T_Zb2Za = T_Zb2Za.to(device)
                D_b = D_b.to(device)

                E_b2Za = E_b2Za.to(device)
                G_Za2a = G_Za2a.to(device)
                T_Za2Zb = T_Za2Zb.to(device)
                D_a = D_a.to(device)
            print("+Successfully saving models in epoch: %3d.-------------+" %
                  epoch)
        #break #-> debug
    record_fh.close()
    print("≧◔◡◔≦ Congratulation! Finishing the training!")
    cwd = os.getcwd()
    dirs = os.listdir(os.path.join(cwd, 'data'))
    print('Loading double univariate normal time series test data...')

    for dsname in dirs:
        params = dsname.split('_')
        if params[2] == 'theta=-1':
            # load saved parameters of encoder and decoder
            encoder.load_state_dict(
                torch.load(os.path.join(cwd, 'checkpoints',
                                        'encoder_' + dsname),
                           map_location=lambda storage, loc: storage))
            decoder.load_state_dict(
                torch.load(os.path.join(cwd, 'checkpoints',
                                        'decoder_' + dsname),
                           map_location=lambda storage, loc: storage))
            encoder = encoder.to(device=device)
            decoder = decoder.to(device=device)

            paired_mnist = DoubleMulNormal(dsname)
            loader = cycle(
                DataLoader(paired_mnist,
                           batch_size=FLAGS.batch_size,
                           shuffle=True,
                           num_workers=0,
                           drop_last=True))
            test_data = torch.from_numpy(paired_mnist.x_test)

            ### Would recommend using this pattern
            use_gpu = torch.cuda.is_available()

            # get the true change points and create list for predicted change points
 
# Paths
checkpoint_path = join('results', args.experiment_name, 'checkpoint')
test_path = join('results', args.experiment_name, 'sample_val')
os.makedirs(test_path, exist_ok=True)

# Data
if args.dataset == 'COCO-Stuff':
    from data import COCO_Stuff
    val_dset = COCO_Stuff(args.data, mode='val')
    n_classes = COCO_Stuff.n_classes
val_data = data.DataLoader(val_dset, batch_size=args.batch_size, shuffle=False, drop_last=False)

# Models
E = Encoder()
E.to(device)
G = Generator(n_classes)
G.to(device)

if args.multi_gpu:  # If trained with multi-GPU, the model needs to be loaded with multi-GPU, too.
    E = nn.DataParallel(E)
    G = nn.DataParallel(G)
    # G = convert_model(G)

# Load from checkpoints
load_epoch = args.test_epoch
if load_epoch is None:  # use the latest model
    load_epoch = max(int(path.split('.')[0]) for path in listdir(checkpoint_path) if path.split('.')[0].isdigit())
print('Loading generator from epoch {:03d}'.format(load_epoch))
E.load_state_dict(torch.load(
    join(checkpoint_path, '{:03d}.E.pth'.format(load_epoch)),