Пример #1
0
def test(args):
    """Load a trained 3D-GAN checkpoint and save plots of generated voxel samples.

    Builds the discriminator/generator pair and their Adam optimizers (so that
    ``read_pickle`` can restore all four from ``args.pickle_dir``), draws one
    latent batch with ``generateZ``, runs the generator once and writes up to
    8 samples with ``SavePloat_Voxels`` into
    ``args.output_dir + args.image_dir + <hyper-parameter string>``.

    Args:
        args: parsed command-line options. Read here: ``model_name``,
            ``cube_len``, ``batch_size``, ``g_lr``, ``d_lr``, ``z_dis``,
            ``bias``, ``soft_label``, ``beta``, ``pickle_dir``, ``output_dir``
            and ``image_dir`` (plus whatever the called helpers read).
    """
    hyparam_list = [("model", args.model_name), ("cube", args.cube_len),
                    ("bs", args.batch_size), ("g_lr", args.g_lr),
                    ("d_lr", args.d_lr), ("z", args.z_dis),
                    ("bias", args.bias), ("sl", args.soft_label)]

    hyparam_dict = OrderedDict(hyparam_list)
    log_param = make_hyparam_string(hyparam_dict)
    print(log_param)

    # model define
    D = _D(args)
    G = _G(args)

    # The optimizers are only built so read_pickle can restore their state.
    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    pickle_path = args.pickle_dir
    read_pickle(pickle_path, G, G_solver, D, D_solver)  # load the models

    # The output used to be labelled with an undefined name ``iteration``
    # (NameError). Use the optimizer step restored from the checkpoint, looked
    # up the same way train() does, or "test" if the optimizer has no state.
    try:
        g_state = G_solver.state_dict()
        iteration = str(
            g_state['state'][g_state['param_groups'][0]['params'][0]]['step'])
    except (KeyError, IndexError):
        iteration = "test"

    Z = generateZ(args)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        fake = G(Z)

    samples = fake.cpu().data[:8].squeeze().numpy()
    image_path = args.output_dir + args.image_dir + log_param
    if not os.path.exists(image_path):
        os.makedirs(image_path)
    SavePloat_Voxels(samples, image_path, iteration)
def test_3DVAEGAN(args):
    """Evaluate a trained 3D-VAE-GAN on the test split and save its outputs.

    For every test batch the encoder ``E`` maps the images to ``(z_mu, z_var)``,
    ``E.reparameterize`` draws a latent code, and the generator ``G`` decodes
    it. The per-sample sum of squared voxel errors against the ground-truth
    model is printed, and up to 8 generated samples per batch are saved with
    ``SavePloat_Voxels`` into ``args.output_dir + args.image_dir + '3DVAEGAN_test'``.

    Args:
        args: parsed command-line options (data/checkpoint/output paths,
            batch size, learning rates, betas).

    Returns:
        float: mean per-sample reconstruction loss over the whole test set
        (0.0 if the loader yields no samples).
    """
    # dataset define
    dsets_path = args.input_dir + args.data_dir + "test/"
    print(dsets_path)
    dsets = ShapeNetPlusImageDataset(dsets_path, args)
    dset_loaders = torch.utils.data.DataLoader(dsets,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)

    # model define
    E = _E(args)
    G = _G(args)

    # The optimizers are only built so read_pickle can restore their state.
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)
    E_solver = optim.Adam(E.parameters(), lr=args.g_lr, betas=args.beta)

    if torch.cuda.is_available():
        print("using cuda")
        G.cuda()
        E.cuda()

    pickle_path = "." + args.pickle_dir + '3DVAEGAN'
    # NOTE(review): G/G_solver are passed twice, apparently standing in for a
    # discriminator pair that this test does not build. Confirm against
    # read_pickle's signature that this does not overwrite G with D weights.
    read_pickle(pickle_path, G, G_solver, G, G_solver, E, E_solver)

    image_path = args.output_dir + args.image_dir + '3DVAEGAN_test'
    if not os.path.exists(image_path):
        os.makedirs(image_path)

    # Accumulate plain floats. Summing the per-sample loss tensors (as before)
    # fails on a differently sized last batch and keeps every autograd graph
    # alive.
    recon_loss_total = 0.0
    n_samples = 0
    with torch.no_grad():  # evaluation only
        for i, (image, model_3d) in enumerate(dset_loaders):

            X = var_or_cuda(model_3d)
            image = var_or_cuda(image)

            z_mu, z_var = E(image)
            Z_vae = E.reparameterize(z_mu, z_var)
            G_vae = G(Z_vae)

            # One value per sample: squared error summed over dims 1-3.
            recon_loss = torch.sum(torch.pow((G_vae - X), 2), dim=(1, 2, 3))
            print(recon_loss.size())
            print("RECON LOSS ITER: ", i, " - ", torch.mean(recon_loss))
            recon_loss_total += recon_loss.sum().item()
            n_samples += recon_loss.numel()

            samples = G_vae.cpu().data[:8].squeeze().numpy()
            SavePloat_Voxels(samples, image_path, i)

    mean_recon_loss = recon_loss_total / n_samples if n_samples else 0.0
    print("MEAN RECON LOSS: ", mean_recon_loss)
    return mean_recon_loss
Пример #3
0
def test_3DGAN(args):
    """Run the trained 3D-GAN generator over the test split and save samples.

    For every full test batch a new latent batch is drawn with ``generateZ``.
    The generator output is compared voxel-wise with the real batch and the
    per-sample squared error is printed. Up to 8 generated samples per batch
    are saved with ``SavePloat_Voxels``. ``G`` is not conditioned on ``X``, so
    this is a sample-vs-data distance rather than a true reconstruction error.

    Args:
        args: parsed command-line options (data/checkpoint/output paths,
            batch size, learning rates, betas).

    Returns:
        float: mean per-sample squared error over all evaluated samples
        (0.0 if no full batch was available).
    """
    # dataset define
    dsets_path = args.input_dir + args.data_dir + "test/"
    print(dsets_path)
    dsets = ShapeNetDataset(dsets_path, args)
    dset_loaders = torch.utils.data.DataLoader(dsets,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)

    # model define
    D = _D(args)
    G = _G(args)

    # The optimizers are only built so read_pickle can restore their state.
    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    pickle_path = "." + args.pickle_dir + '3DVAEGAN_MULTIVIEW_MAX'
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    image_path = args.output_dir + args.image_dir + '3DVAEGAN_MULTIVIEW_MAX_test'
    if not os.path.exists(image_path):
        os.makedirs(image_path)

    # Accumulate plain floats. Summing per-sample loss tensors (as before)
    # fails on differently sized batches and keeps autograd graphs alive.
    recon_loss_total = 0.0
    n_samples = 0
    with torch.no_grad():  # evaluation only
        for i, X in enumerate(dset_loaders):
            X = var_or_cuda(X)
            # Assumes generateZ(args) yields args.batch_size latents, the same
            # assumption train() makes. A shorter trailing batch would make
            # ``fake - X`` fail, so skip it as train() does.
            if X.size(0) != int(args.batch_size):
                continue
            print(X.size())
            Z = generateZ(args)
            print(Z.size())
            fake = G(Z).squeeze()
            print(fake.size())
            recon_loss = torch.sum(torch.pow((fake - X), 2), dim=(1, 2, 3))
            print(recon_loss.size())
            print("RECON LOSS ITER: ", i, " - ", torch.mean(recon_loss))
            recon_loss_total += recon_loss.sum().item()
            n_samples += recon_loss.numel()

            samples = fake.cpu().data[:8].squeeze().numpy()
            SavePloat_Voxels(samples, image_path, i)

    mean_recon_loss = recon_loss_total / n_samples if n_samples else 0.0
    print("MEAN RECON LOSS: ", mean_recon_loss)
    return mean_recon_loss
Пример #4
0
def train(args):
    """Train the voxel 3D-GAN (generator ``G`` against discriminator ``D``).

    For each full batch:
      1. Update ``D`` on real vs. generated voxels, but only while its batch
         accuracy is ``<= args.d_thresh``.
      2. Update ``G`` so that ``D`` labels fresh samples as real.

    At the end of each epoch the latest losses are printed, and optionally
    logged to TensorBoard. Every ``args.image_save_step`` epochs sample voxels
    are saved; every ``args.pickle_step`` epochs a checkpoint is written.

    Args:
        args: parsed command-line options (hyper-parameters, paths, flags such
            as ``use_tensorboard``, ``soft_label`` and ``lrsh``).

    Raises:
        ValueError: if the dataset does not contain a single full batch of
            ``args.batch_size`` samples. Incomplete batches are skipped, so
            nothing would be trained or logged.
    """
    hyparam_list = [
        ("model", args.model_name),
        ("cube", args.cube_len),
        ("bs", args.batch_size),
        ("g_lr", args.g_lr),
        ("d_lr", args.d_lr),
        ("z", args.z_dis),
        ("bias", args.bias),
        ("sl", args.soft_label),
    ]

    hyparam_dict = OrderedDict(hyparam_list)
    log_param = make_hyparam_string(hyparam_dict)
    print(log_param)

    # for using tensorboard (TF1-style summary API)
    if args.use_tensorboard:
        import tensorflow as tf

        summary_writer = tf.summary.FileWriter(args.output_dir + args.log_dir +
                                               log_param)

        def inject_summary(summary_writer, tag, value, step):
            """Write the scalar ``value`` under ``tag`` at ``step``."""
            summary = tf.Summary(
                value=[tf.Summary.Value(tag=tag, simple_value=value)])
            summary_writer.add_summary(summary, global_step=step)

    # dataset define
    dsets_path = args.input_dir + args.data_dir + "train/"
    print(dsets_path)
    dsets = ShapeNetDataset(dsets_path, args)
    dset_loaders = torch.utils.data.DataLoader(dsets,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)

    # model define
    D = _D(args)
    G = _G(args)

    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if args.lrsh:
        D_scheduler = MultiStepLR(D_solver, milestones=[500, 1000])

    if torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    criterion = nn.BCELoss()

    pickle_path = "." + args.pickle_dir + log_param
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    for epoch in range(args.n_epochs):
        n_trained = 0
        for i, X in enumerate(dset_loaders):

            X = var_or_cuda(X)

            # The label tensors below are sized by args.batch_size, so an
            # incomplete (last) batch is dropped.
            if X.size()[0] != int(args.batch_size):
                continue
            n_trained += 1

            Z = generateZ(args)
            real_labels = var_or_cuda(torch.ones(args.batch_size))
            fake_labels = var_or_cuda(torch.zeros(args.batch_size))

            if args.soft_label:
                real_labels = var_or_cuda(
                    torch.Tensor(args.batch_size).uniform_(0.7, 1.2))
                fake_labels = var_or_cuda(
                    torch.Tensor(args.batch_size).uniform_(0, 0.3))

            # ============= Train the discriminator =============#
            d_real = D(X)
            d_real_loss = criterion(d_real, real_labels)

            fake = G(Z)
            # detach(): the D step must not backprop into G. G's gradients are
            # zeroed before its own step anyway, so updates are unchanged;
            # this only skips the wasted backward pass.
            d_fake = D(fake.detach())
            d_fake_loss = criterion(d_fake, fake_labels)

            d_loss = d_real_loss + d_fake_loss

            # Fraction of real scored >= 0.5 and fake scored <= 0.5.
            d_real_acu = torch.ge(d_real.squeeze(), 0.5).float()
            d_fake_acu = torch.le(d_fake.squeeze(), 0.5).float()
            d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))

            # Skip the D update once D is already accurate enough.
            if d_total_acu <= args.d_thresh:
                D.zero_grad()
                d_loss.backward()
                D_solver.step()

            # =============== Train the generator ===============#

            Z = generateZ(args)

            fake = G(Z)
            d_fake = D(fake)
            g_loss = criterion(d_fake, real_labels)

            D.zero_grad()
            G.zero_grad()
            g_loss.backward()
            G_solver.step()

        # Without a single full batch the loss variables below would be
        # undefined (previously a NameError).
        if n_trained == 0:
            raise ValueError(
                "no full batch of size {} in {}; nothing to train on".format(
                    args.batch_size, dsets_path))

        # =============== logging each iteration ===============#
        iteration = str(G_solver.state_dict()['state'][
            G_solver.state_dict()['param_groups'][0]['params'][0]]['step'])

        # ``.item()`` replaces the pre-0.4 ``.data[0]`` idiom, which raises
        # IndexError on the 0-dim loss/accuracy tensors of current PyTorch.
        d_loss_value = d_loss.item()
        g_loss_value = g_loss.item()
        d_acc_value = d_total_acu.item()

        if args.use_tensorboard:
            log_save_path = args.output_dir + args.log_dir + log_param
            if not os.path.exists(log_save_path):
                os.makedirs(log_save_path)

            info = {
                'loss/loss_D_R': d_real_loss.item(),
                'loss/loss_D_F': d_fake_loss.item(),
                'loss/loss_D': d_loss_value,
                'loss/loss_G': g_loss_value,
                'loss/acc_D': d_acc_value,
            }

            for tag, value in info.items():
                inject_summary(summary_writer, tag, value, iteration)

            summary_writer.flush()

        # =============== each epoch save model or save image ===============#
        print(
            'Iter-{}; , D_loss : {:.4}, G_loss : {:.4}, D_acu : {:.4}, D_lr : {:.4}'
            .format(iteration, d_loss_value, g_loss_value, d_acc_value,
                    D_solver.state_dict()['param_groups'][0]["lr"]))

        if (epoch + 1) % args.image_save_step == 0:

            samples = fake.cpu().data[:8].squeeze().numpy()

            image_path = args.output_dir + args.image_dir + log_param
            if not os.path.exists(image_path):
                os.makedirs(image_path)

            SavePloat_Voxels(samples, image_path, iteration)

        if (epoch + 1) % args.pickle_step == 0:
            pickle_save_path = args.output_dir + args.pickle_dir + log_param
            save_new_pickle(pickle_save_path, iteration, G, G_solver, D,
                            D_solver)

        if args.lrsh:
            # Best-effort: a scheduler failure is reported, not fatal.
            try:
                D_scheduler.step()
            except Exception as e:
                print("fail lr scheduling", e)
Пример #5
0
def _make_labels(size, real, soft_label, use_cuda):
    """Build a 1-D BCE target tensor of length ``size``.

    Hard labels are all ones (``real=True``) or all zeros (``real=False``).
    Soft labels are drawn uniformly from [0.7, 1.2] for real targets and from
    [0, 0.3] for fake ones. The tensor is moved to CUDA when ``use_cuda`` is
    truthy.
    """
    if soft_label:
        low, high = (0.7, 1.2) if real else (0, 0.3)
        labels = torch.Tensor(size).uniform_(low, high)
    else:
        labels = torch.ones(size) if real else torch.zeros(size)
    return labels.cuda() if use_cuda else labels


def _save_models(save_dir, tag, G, D, D_frame_motion):
    """Save the three networks' state dicts as ``<name>_<tag>.pth`` in ``save_dir``."""
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    torch.save(G.state_dict(), '{}/G_{}.pth'.format(save_dir, tag))
    torch.save(D.state_dict(), '{}/D_{}.pth'.format(save_dir, tag))
    torch.save(D_frame_motion.state_dict(),
               '{}/D_frame_motion_{}.pth'.format(save_dir, tag))


def _write_captions(txt_path, captions, ixtoword):
    """Append each caption to ``txt_path`` as one line of space-separated words.

    Each caption is a sequence of word-index tensors. Writing stops at the
    first index 0. ``ixtoword`` maps indices to words.
    """
    with open(txt_path, 'a') as f:
        for caption in captions:
            for token in caption:
                idx = token.item()
                if idx == 0:
                    break
                f.write(ixtoword[idx])
                f.write(' ')
            f.write('\n')


def train(args):
    """Train the text-to-video GAN.

    The generator ``G`` is trained against a video discriminator ``D`` and a
    frame/motion discriminator ``D_frame_motion``.

    Per batch:
      * Sentence/word embeddings come from a frozen ``RNN_ENCODER``.
      * ``D`` is trained on real, "wrong" (``X_wrong`` from
        ``prepare_data_real_wrong``, scored as fake) and generated videos.
      * ``D_frame_motion`` is then trained on the same three kinds of input.
      * ``G`` is trained with the adversarial losses of both discriminators,
        plus the DAMSM word/sentence losses (frozen ``CNN_ENCODER``, weighted
        by ``args.lamb``) and ``KL_loss(mu, logvar)``.

    Per epoch:
      * Losses are printed.
      * The last generated batch is written for FID evaluation, FID is
        computed against ``args.fid_real_path``, and the best-FID models are
        saved as ``*_best.pth``.
      * Every ``args.image_save_step`` epochs, sample images and captions are
        written.
      * Every ``args.pickle_step`` epochs, ``*_epoch_<n>.pth`` checkpoints are
        written.

    NOTE(review): ``args.video_loss``, ``args.frame_motion_loss`` and
    ``args.cls`` only change what is printed; both discriminators are always
    trained.

    Args:
        args: parsed command-line options (devices, data paths, model and
            optimizer hyper-parameters, checkpoint paths, output folders).

    Raises:
        ValueError: if the data loader yields no batches (``drop_last=True``
            with fewer than ``args.batch_size`` samples).
    """
    # set devices
    torch.cuda.set_device(args.gpu)
    torch.backends.cudnn.benchmark = True

    # ---- data loader ----
    image_transform = transforms.Compose([
        transforms.Resize(args.imageSize),
        transforms.CenterCrop(args.imageSize)
    ])
    dataset = TextDataset(args.dataroot,
                          'train',
                          base_size=args.imageSize,
                          transform=image_transform,
                          input_channels=args.input_channels,
                          image_type=args.image_type)
    dset_loaders = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.batch_size,
                                               drop_last=True,
                                               shuffle=True,
                                               num_workers=args.workers)
    # With drop_last=True an undersized dataset yields no batches, and the
    # per-epoch logging below would hit undefined loss variables.
    if len(dset_loaders) == 0:
        raise ValueError(
            "training set has fewer than batch_size={} samples; no batch to "
            "train on".format(args.batch_size))

    # ---- define D and G ----
    D = _D(args)
    D_frame_motion = _D_frame_motion(args)
    G = _G(args)

    if args.cuda and args.gpu_num > 0:
        device_ids = range(args.gpu, args.gpu + args.gpu_num)
        D = nn.DataParallel(D, device_ids=device_ids)
        D_frame_motion = nn.DataParallel(D_frame_motion,
                                         device_ids=device_ids)
        G = nn.DataParallel(G, device_ids=device_ids)

    # Checkpoints are loaded after the optional DataParallel wrap, matching
    # how this function saves them.
    if args.checkpoint_D != '':
        D.load_state_dict(torch.load(args.checkpoint_D))
    if args.checkpoint_frame_motion_D != '':
        D_frame_motion.load_state_dict(
            torch.load(args.checkpoint_frame_motion_D))
    if args.checkpoint_G != '':
        G.load_state_dict(torch.load(args.checkpoint_G))

    # ---- load text and image encoders and freeze their parameters ----
    text_encoder = RNN_ENCODER(dataset.n_words,
                               nhidden=args.hidden_size,
                               batch_size=args.batch_size)
    image_encoder = CNN_ENCODER(args.hidden_size, args.input_channels)

    if args.checkpoint_text_encoder != '':
        text_encoder.load_state_dict(torch.load(args.checkpoint_text_encoder))
    if args.checkpoint_image_encoder != '':
        image_encoder.load_state_dict(torch.load(
            args.checkpoint_image_encoder))

    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from: %s' % args.checkpoint_text_encoder)
    text_encoder.eval()

    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from: %s' % args.checkpoint_image_encoder)
    image_encoder.eval()

    # ---- optimizers and losses ----
    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    D_solver_frame = optim.Adam(D_frame_motion.parameters(),
                                lr=args.d_lr,
                                betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    criterion = nn.BCELoss()
    criterion_l1 = nn.L1Loss()

    if args.cuda:
        print("using cuda")
        D.cuda()
        D_frame_motion.cuda()
        G.cuda()
        text_encoder.cuda()
        image_encoder.cuda()
        criterion = criterion.cuda()
        criterion_l1 = criterion_l1.cuda()

    if args.cls:
        print("using cls")
    if args.A:
        print("using A")
    if args.video_loss:
        print("using video discriminator")
    if args.frame_motion_loss:
        print("using frame and motion discriminator")

    # estimate time
    start = time.time()

    # ---- discriminator targets (sampled once, reused for all batches) ----
    n_frame = args.batch_size * args.frame_num
    n_motion = args.batch_size * (args.frame_num - 1)
    real_labels = _make_labels(args.batch_size, True, args.soft_label,
                               args.cuda)
    fake_labels = _make_labels(args.batch_size, False, args.soft_label,
                               args.cuda)
    real_labels_frame = _make_labels(n_frame, True, args.soft_label,
                                     args.cuda)
    fake_labels_frame = _make_labels(n_frame, False, args.soft_label,
                                     args.cuda)
    # Motion labels used to be created only when args.A was set, yet the
    # training loop always uses them, so runs without --A crashed with
    # NameError.
    real_labels_motion = _make_labels(n_motion, True, args.soft_label,
                                      args.cuda)
    fake_labels_motion = _make_labels(n_motion, False, args.soft_label,
                                      args.cuda)

    # matching labels for the DAMSM losses: sample k matches caption k
    match_labels = torch.LongTensor(range(args.batch_size))
    if args.cuda:
        match_labels = match_labels.cuda()

    best_fid = 10000.0
    best_epoch = 0
    for epoch in range(args.n_epochs):
        for i, data in enumerate(dset_loaders, 0):
            if i % 10 == 0:
                print('Epoch[{}][{}/{}] Time: {}'.format(
                    epoch, i, len(dset_loaders),
                    timeSince(
                        start,
                        float(epoch * len(dset_loaders) + i) /
                        float(args.n_epochs * len(dset_loaders)))))

            # ---- prepare training data ----
            (X, captions, cap_lens, class_ids, keys,
             X_wrong, caps_wrong, cap_len_wrong, cls_id_wrong,
             key_wrong) = prepare_data_real_wrong(data)

            # Only the first element of the returned image lists is used.
            X = X[0]
            X_wrong = X_wrong[0]

            hidden = text_encoder.init_hidden()

            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

            # ---- generate fake videos ----
            Z = generateZ(args)
            fake, mu, logvar = G(Z, sent_emb)

            # ---- train the video discriminator ----
            # Fakes are detached in the discriminator steps: G's gradients are
            # zeroed before its own step anyway, so updates are unchanged and
            # the wasted backward pass through G is skipped.
            D.zero_grad()

            d_real = D(X, sent_emb)
            d_real_loss = criterion(d_real, real_labels)

            d_wrong = D(X_wrong, sent_emb)
            d_wrong_loss = criterion(d_wrong, fake_labels)

            d_fake = D(fake.detach(), sent_emb)
            d_fake_loss = criterion(d_fake, fake_labels)

            d_loss = d_real_loss + d_fake_loss + d_wrong_loss
            d_loss.backward()
            D_solver.step()

            # ---- train the frame and motion discriminator ----
            D_frame_motion.zero_grad()

            d_real_frame, d_real_motion = D_frame_motion(X, sent_emb)
            d_real_loss_frame = criterion(d_real_frame, real_labels_frame)
            d_real_loss_motion = criterion(d_real_motion, real_labels_motion)

            d_wrong_frame, d_wrong_motion = D_frame_motion(X_wrong, sent_emb)
            d_wrong_loss_frame = criterion(d_wrong_frame, fake_labels_frame)
            d_wrong_loss_motion = criterion(d_wrong_motion, fake_labels_motion)

            # Re-run G on the same Z for this discriminator.
            fake, mu, logvar = G(Z, sent_emb)

            d_fake_frame, d_fake_motion = D_frame_motion(fake.detach(),
                                                         sent_emb)
            d_fake_loss_frame = criterion(d_fake_frame, fake_labels_frame)
            d_fake_loss_motion = criterion(d_fake_motion, fake_labels_motion)

            d_loss_frame = (d_real_loss_frame + d_fake_loss_frame +
                            d_wrong_loss_frame + d_real_loss_motion +
                            d_wrong_loss_motion + d_fake_loss_motion)
            d_loss_frame.backward()
            D_solver_frame.step()

            # ---- train the generator ----
            G.zero_grad()

            fake, mu, logvar = G(Z, sent_emb)

            # adversarial losses from both discriminators
            d_fake = D(fake, sent_emb)
            g_loss = criterion(d_fake, real_labels)

            d_fake_frame, d_fake_motion = D_frame_motion(fake, sent_emb)
            g_loss_frame = criterion(d_fake_frame, real_labels_frame)
            g_loss_motion = criterion(d_fake_motion, real_labels_motion)

            g_loss = g_loss + g_loss_frame + g_loss_motion

            # DAMSM word- and sentence-level matching losses
            region_features, cnn_code = image_encoder(fake)

            w_loss0, w_loss1, _ = words_loss(region_features, words_embs,
                                             match_labels, cap_lens, class_ids,
                                             args.batch_size)
            w_loss = (w_loss0 + w_loss1) * args.lamb

            s_loss0, s_loss1 = sent_loss(cnn_code, sent_emb, match_labels,
                                         class_ids, args.batch_size)
            s_loss = (s_loss0 + s_loss1) * args.lamb

            g_loss = g_loss + w_loss + s_loss

            # KL regularizer on (mu, logvar) returned by G
            kl_loss = KL_loss(mu, logvar)
            g_loss += kl_loss

            g_loss.backward()
            G_solver.step()

        # =============== each epoch: log, evaluate, save =============== #
        if args.frame_motion_loss == True and args.video_loss == False:
            print('Epoch-{}; D_loss_frame:{:.4}, G_loss:{:.4}, D_lr:{:.4}'.
                  format(epoch, d_loss_frame.item(), g_loss.item(),
                         D_solver.state_dict()['param_groups'][0]["lr"]))
        elif args.frame_motion_loss == False and args.video_loss == True:
            print('Epoch-{}; D_loss_video:{:.4}, G_loss:{:.4}, D_lr:{:.4}'.
                  format(epoch, d_loss.item(), g_loss.item(),
                         D_solver.state_dict()['param_groups'][0]["lr"]))
        else:
            print(
                'Epoch-{}; D_loss_video:{:.4}, D_loss_frame:{:.4}, G_loss:{:.4}, D_lr:{:.4}'
                .format(epoch, d_loss.item(), d_loss_frame.item(),
                        g_loss.item(),
                        D_solver.state_dict()['param_groups'][0]["lr"]))

        # write the last generated batch for the FID computation
        fid_image_path = os.path.join(args.output_dir,
                                      args.fid_fake_foldername, "images")
        if not os.path.exists(fid_image_path):
            os.makedirs(fid_image_path)
        vutils.save_image_forFID(fake,
                                 '{}/fake_samples_epoch_{}_{}.png'.format(
                                     fid_image_path, epoch + 1, i),
                                 normalize=True,
                                 pad_value=1,
                                 input_channels=args.input_channels,
                                 imageSize=args.imageSize,
                                 fid_image_path=fid_image_path)

        path_fid_images = [
            args.fid_real_path,
            os.path.join(args.output_dir, args.fid_fake_foldername)
        ]
        print('calculate the fid score ...')
        # Best-effort: a failed FID computation must not stop training, but it
        # is reported. A bare ``except:`` also swallowed KeyboardInterrupt.
        try:
            fid_value = fid_score.calculate_fid_score(
                path=path_fid_images,
                batch_size=args.batch_size,
                gpu=str(args.gpu))
        except Exception as e:
            print('fid calculation failed:', e)
            fid_value = best_fid
        if fid_value < best_fid:
            best_fid = fid_value
            best_epoch = epoch
            _save_models(os.path.join(args.output_dir, args.pickle_dir),
                         'best', G, D, D_frame_motion)

        print(
            "\033[1;31m current_epoch[{}] current_fid[{}] \033[0m \033[1;34m best_epoch[{}] best_fid[{}] \033[0m"
            .format(epoch, fid_value, best_epoch, best_fid))

        # save fid
        with open(os.path.join(args.output_dir, 'log_fid.txt'), 'a') as f:
            f.write(
                "current_epoch[{}] current_fid[{}] best_epoch[{}] best_fid[{}] \n"
                .format(epoch, fid_value, best_epoch, best_fid))

        # save images and the matching captions
        if (epoch + 1) % args.image_save_step == 0:
            image_path = os.path.join(args.output_dir, args.image_dir)
            if not os.path.exists(image_path):
                os.makedirs(image_path)
            vutils.save_image(fake,
                              '{}/fake_samples_epoch_{}_{}.png'.format(
                                  image_path, epoch + 1, i),
                              normalize=True,
                              pad_value=1,
                              input_channels=args.input_channels,
                              imageSize=args.imageSize)
            # ``xrange`` (Python 2 only) raised NameError here on Python 3.
            _write_captions('{}/{:0>4d}.txt'.format(image_path, epoch + 1),
                            captions, dataset.ixtoword)

        # periodic checkpoint
        if (epoch + 1) % args.pickle_step == 0:
            _save_models(os.path.join(args.output_dir, args.pickle_dir),
                         'epoch_{}'.format(epoch + 1), G, D, D_frame_motion)
Пример #6
0
def train(args):
    """Train the voxel GAN (_G / _D) on a 3D-MNIST voxel array.

    The training data is read from ``voxels_3DMNIST_16.npy`` in the current
    working directory and each sample is flattened to ``cube_len ** 3``
    values. Every full batch runs one discriminator update (skipped while the
    discriminator's batch accuracy exceeds ``args.d_thresh``) followed by one
    generator update. Generator samples are plotted in Visdom every
    ``VISDOM_PLOT_EVERY`` generator steps; voxel images and checkpoints are
    written every ``args.image_save_step`` / ``args.pickle_step`` epochs.

    Args:
        args: parsed option namespace. Fields read directly in this function:
            model_name, cube_len, batch_size, g_lr, d_lr, beta, z_dis, bias,
            soft_label, d_thresh, n_epochs, use_tensorboard, input_dir,
            data_dir, output_dir, log_dir, image_dir, pickle_dir,
            image_save_step, pickle_step.
    """
    # Visdom's constructor is Visdom(server, endpoint, port, ...): the port
    # must be passed by keyword, otherwise it is taken as the endpoint.
    DEFAULT_PORT = 8097
    DEFAULT_HOSTNAME = "http://localhost"
    viz = Visdom(server=DEFAULT_HOSTNAME, port=DEFAULT_PORT, ipv6=False)
    # Plot generator samples to Visdom every this many generator steps.
    VISDOM_PLOT_EVERY = 300
    # Number of generated samples taken from a batch for plotting/saving.
    N_SAMPLES = 8

    hyparam_list = [
        ("model", args.model_name),
        ("cube", args.cube_len),
        ("bs", args.batch_size),
        ("g_lr", args.g_lr),
        ("d_lr", args.d_lr),
        ("z", args.z_dis),
        ("bias", args.bias),
        ("sl", args.soft_label),
    ]

    hyparam_dict = OrderedDict(hyparam_list)
    log_param = make_hyparam_string(hyparam_dict)
    print(log_param)

    # for using tensorboard (TF1-style summary API)
    if args.use_tensorboard:
        import tensorflow as tf

        summary_writer = tf.summary.FileWriter(args.output_dir + args.log_dir +
                                               log_param)

        def inject_summary(summary_writer, tag, value, step):
            """Write one scalar summary ``tag = value`` at ``step``."""
            summary = tf.Summary(
                value=[tf.Summary.Value(tag=tag, simple_value=value)])
            summary_writer.add_summary(summary, global_step=step)

    # dataset define
    dsets_path = args.input_dir + args.data_dir + "train/"
    # NOTE(review): dsets_path is only printed; the data actually comes from
    # the hard-coded .npy file below.
    print(dsets_path)

    x_train = np.load("voxels_3DMNIST_16.npy")
    dataset = x_train.reshape(-1,
                              args.cube_len * args.cube_len * args.cube_len)
    print(dataset.shape)
    dset_loaders = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)

    # model define
    D = _D(args)
    G = _G(args)

    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        print("using cuda")
        D.cuda()
        G.cuda()

    criterion = nn.BCELoss()

    pickle_path = "." + args.pickle_dir + log_param
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    # Adam keeps its step counter per parameter; the counter of the first
    # generator parameter is used as the global iteration number.
    g_first_param = G_solver.param_groups[0]['params'][0]
    batch_size = int(args.batch_size)
    # Last generated batch and step counter; both stay None until a full
    # batch has been processed.
    fake = None
    iteration = None

    for epoch in range(args.n_epochs):
        epoch_start_time = time.time()
        print("epoch %d started" % (epoch))
        for i, X in enumerate(dset_loaders):
            # Drop the last incomplete batch: the labels below are sized by
            # args.batch_size.
            if X.size()[0] != batch_size:
                continue

            # float32, moved to the GPU only when one is available.
            X = var_or_cuda(X).float()
            if use_cuda:
                X = X.cuda()

            Z = generateZ(args)
            if args.soft_label:
                # Soft targets for real samples; fake targets stay at 0.
                # NOTE(review): values up to 1.1 lie outside BCELoss's usual
                # [0, 1] target range — confirm this is intended.
                real_labels = var_or_cuda(
                    torch.Tensor(args.batch_size).uniform_(0.9, 1.1)).view(
                        -1, 1, 1, 1, 1)
            else:
                real_labels = var_or_cuda(torch.ones(args.batch_size)).view(
                    -1, 1, 1, 1, 1)
            fake_labels = var_or_cuda(torch.zeros(args.batch_size)).view(
                -1, 1, 1, 1, 1)

            # ============= Train the discriminator =============#
            d_real = D(X)
            d_real_loss = criterion(d_real, real_labels)

            fake = G(Z)
            d_fake = D(fake)
            d_fake_loss = criterion(d_fake, fake_labels)

            d_loss = d_real_loss + d_fake_loss

            # Fraction of real samples scored >= 0.5 and fakes scored <= 0.5.
            d_real_acu = torch.ge(d_real.squeeze(), 0.5).float()
            d_fake_acu = torch.le(d_fake.squeeze(), 0.5).float()
            d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))

            # Only update D while it is not already too accurate.
            if d_total_acu <= args.d_thresh:
                D.zero_grad()
                d_loss.backward()
                D_solver.step()

            # =============== Train the generator ===============#
            Z = generateZ(args)

            fake = G(Z)
            d_fake = D(fake)
            g_loss = criterion(d_fake, real_labels)

            D.zero_grad()
            G.zero_grad()
            g_loss.backward()
            G_solver.step()

            # =============== logging each iteration ===============#
            step = G_solver.state[g_first_param]['step']
            if torch.is_tensor(step):  # newer PyTorch stores it as a tensor
                step = step.item()
            iteration = str(int(step))

            if int(step) % VISDOM_PLOT_EVERY == 0:
                samples = fake.cpu().data[:N_SAMPLES].squeeze().numpy()
                # Plain "{}": a precision such as "{:.4}" truncates a str.
                title = "Iteration:{}".format(iteration)
                for s in range(min(N_SAMPLES, len(samples))):
                    plotVoxelVisdom(samples[s, ...], viz, title)

            print(
                'Iter-{}; , D_loss : {:.4}, G_loss : {:.4}, D_acu : {:.4}, D_lr : {:.4}'
                .format(iteration, d_loss.item(), g_loss.item(),
                        d_total_acu.item(),
                        D_solver.param_groups[0]["lr"]))

        epoch_end_time = time.time()

        # =============== each epoch save model or save image ===============#
        # Nothing can be saved before at least one full batch has run.
        if fake is not None and (epoch + 1) % args.image_save_step == 0:
            samples = fake.cpu().data[:N_SAMPLES].squeeze().numpy()

            image_path = args.output_dir + args.image_dir + log_param
            if not os.path.exists(image_path):
                os.makedirs(image_path)

            SavePloat_Voxels(samples, image_path, iteration)

        if iteration is not None and (epoch + 1) % args.pickle_step == 0:
            pickle_save_path = args.output_dir + args.pickle_dir + log_param
            save_new_pickle(pickle_save_path, iteration, G, G_solver, D,
                            D_solver)

        print("epoch time", (epoch_end_time - epoch_start_time) / 60)
        print("epoch %d ended" % (epoch))
        print("################################################")
def train(args):
    """Train the voxel WGAN-GP (_G / _D) on the vertebra volume dataset.

    Volumes are loaded from ``../Vert_dataset`` with VertDataset, resized with
    ``ResizeTo(args.cube_len)`` and flattened to ``cube_len ** 3`` values per
    row. The critic is updated on every batch (``train_critic``) and the
    generator every ``n_critic`` batches (``train_gen``). After each epoch the
    latest generator/critic losses are appended and plotted with
    ``plot_losess``. Voxel samples and checkpoints are written every
    ``args.image_save_step`` / ``args.pickle_step`` epochs.

    Args:
        args: parsed option namespace. Fields read directly in this function:
            model_name, cube_len, batch_size, g_lr, d_lr, beta, z_dis, z_size,
            bias, n_epochs, output_dir, image_dir, pickle_dir,
            image_save_step, pickle_step.
    """
    # WGAN related params
    # NOTE(review): lambda_gp is not passed to train_critic; confirm which
    # gradient-penalty weight train_critic actually uses.
    lambda_gp = 10
    n_critic = 5  # batches per generator update

    hyparam_list = [
        ("model", args.model_name),
        ("cube", args.cube_len),
        ("bs", args.batch_size),
        ("g_lr", args.g_lr),
        ("d_lr", args.d_lr),
        ("z", args.z_dis),
        ("bias", args.bias),
    ]

    hyparam_dict = OrderedDict(hyparam_list)
    log_param = make_hyparam_string(hyparam_dict)
    print(log_param)

    # define different paths
    pickle_path = "." + args.pickle_dir + log_param
    image_path = args.output_dir + args.image_dir + log_param
    pickle_save_path = args.output_dir + args.pickle_dir + log_param

    N = None  # None for the whole dataset
    # Resize to cube_len so that each volume flattens to exactly one row of
    # cube_len ** 3 values below (assuming ResizeTo(n) yields n*n*n volumes);
    # a fixed 64 mis-split volumes whenever cube_len != 64.
    VOL_SIZE = args.cube_len
    train_path = pathlib.Path("../Vert_dataset")
    dataset = VertDataset(train_path,
                          n=N,
                          transform=transforms.Compose(
                              [ResizeTo(VOL_SIZE),
                               transforms.ToTensor()]))
    print('Number of samples: ', len(dataset))
    dset_loaders = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=0)
    print('Number of batches: ', len(dset_loaders))

    #  Build the model
    D = _D(args)
    G = _G(args)

    # Create the solvers
    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    use_cuda = torch.cuda.is_available()
    if torch.cuda.device_count() > 1:
        D = nn.DataParallel(D)
        G = nn.DataParallel(G)
        print("Using {} GPUs".format(torch.cuda.device_count()))
        D.cuda()
        G.cuda()

    elif use_cuda:
        print("using cuda")
        D.cuda()
        G.cuda()

    # Load checkpoint if available
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    # Generator step counter: Adam state of the first generator parameter.
    g_first_param = G_solver.param_groups[0]['params'][0]
    cube_numel = args.cube_len * args.cube_len * args.cube_len

    G_losses = []
    D_losses = []
    iteration = None  # set once the first batch has been processed

    for epoch in range(args.n_epochs):
        epoch_start_time = time.time()
        print("epoch %d started" % (epoch))
        for i, X in enumerate(dset_loaders):
            X = X.view(-1, cube_numel)
            # float32, moved to the GPU only when one is available.
            X = var_or_cuda(X).float()
            if use_cuda:
                X = X.cuda()
            Z = generateZ(num_samples=X.size(0), z_size=args.z_size)

            # Train the critic on every batch
            d_loss, Wasserstein_D, gp = train_critic(X, Z, D, G, D_solver,
                                                     G_solver)

            # Train the generator every n_critic steps. i == 0 always trains,
            # so g_loss exists from the first batch on; on other batches the
            # logged G_loss is the most recent generator loss.
            if i % n_critic == 0:
                Z = generateZ(num_samples=X.size(0), z_size=args.z_size)
                g_loss = train_gen(Z, D, G, D_solver, G_solver)

            # Log each iteration
            step = G_solver.state[g_first_param]['step']
            if torch.is_tensor(step):  # newer PyTorch stores it as a tensor
                step = step.item()
            iteration = str(int(step))
            print('Iter-{}; , D_loss : {:.4}, G_loss : {:.4}, WSdistance : {:.4}, GP : {:.4}'.format(
                iteration, d_loss.item(), g_loss.item(),
                Wasserstein_D.item(), gp.item()))
        ## End of epoch
        epoch_end_time = time.time()

        # Skip plotting/saving when no batch has been processed yet.
        if iteration is not None:
            # Plot the losses each epoch
            G_losses.append(g_loss.item())
            D_losses.append(d_loss.item())
            plot_losess(G_losses, D_losses, epoch)

            if (epoch + 1) % args.image_save_step == 0:
                print("Saving voxels")
                Z = generateZ(num_samples=8, z_size=args.z_size)
                # Sampling only: no autograd graph is needed.
                with torch.no_grad():
                    gen_output = G(Z)
                samples = gen_output.cpu().data[:8].squeeze().numpy()
                samples = samples.reshape(-1, args.cube_len, args.cube_len,
                                          args.cube_len)
                pathlib.Path(image_path).mkdir(parents=True, exist_ok=True)
                Save_Voxels(samples, image_path, iteration)

            if (epoch + 1) % args.pickle_step == 0:
                print("Pickeling the model")
                save_new_pickle(pickle_save_path, iteration, G, G_solver, D,
                                D_solver)

        print("epoch time", (epoch_end_time - epoch_start_time) / 60)
        print("epoch %d ended" % (epoch))
        print("################################################")
Пример #8
0
def test(args):
    """Sample images from a pretrained text-conditioned generator.

    Iterates over the 'train' split of TextDataset. For each batch it encodes
    the captions with a frozen RNN_ENCODER, feeds noise plus the sentence
    embedding to _G, and every ``args.image_save_step`` epochs writes the
    generated batch with ``vutils.save_image`` together with a text file of
    the decoded captions. Both files go to ``args.output_dir/args.image_dir``.

    Args:
        args: parsed option namespace. Fields read directly in this function:
            gpu, imageSize, dataroot, batch_size, workers, hidden_size, cuda,
            gpu_num, checkpoint_G, checkpoint_text_encoder, cls, A,
            video_loss, frame_motion_loss, n_epochs, image_save_step,
            output_dir, image_dir.
    """
    # set devices
    torch.cuda.set_device(args.gpu)
    torch.backends.cudnn.benchmark = True

    # Get data loader ##################################################
    image_transform = transforms.Compose([transforms.Resize(args.imageSize)])
    dataset = TextDataset(args.dataroot,
                          'train',
                          base_size=args.imageSize,
                          transform=image_transform)
    dset_loaders = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.batch_size,
                                               drop_last=True,
                                               shuffle=True,
                                               num_workers=args.workers)

    # ============================== #
    # Define G and the text encoder  #
    # ============================== #
    G = _G(args)
    text_encoder = RNN_ENCODER(dataset.n_words,
                               nhidden=args.hidden_size,
                               batch_size=args.batch_size)

    if args.cuda:
        print("using cuda")
        if args.gpu_num >= 1:
            # NOTE(review): G is wrapped even for a single GPU, so the G
            # checkpoint must use DataParallel's "module."-prefixed keys.
            device_ids = range(args.gpu, args.gpu + args.gpu_num)
            G = nn.DataParallel(G, device_ids=device_ids)
        G.cuda()
        text_encoder.cuda()

    if args.checkpoint_G != '':
        G.load_state_dict(torch.load(args.checkpoint_G))
    # NOTE(review): G is left in training mode here (no G.eval()); confirm
    # that is intended for sampling.

    # ================================================ #
    # Load the text encoder and freeze its parameters  #
    # ================================================ #
    if args.checkpoint_text_encoder != '':
        text_encoder.load_state_dict(torch.load(args.checkpoint_text_encoder))

    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from: %s' % args.checkpoint_text_encoder)
    text_encoder.eval()

    if args.cls:
        print("using cls")
    if args.A:
        print("using A")
    if args.video_loss:
        print("using video discriminator")
    if args.frame_motion_loss:
        print("using frame and motion discriminator")

    # estimate time
    start = time.time()

    # iteration
    for epoch in range(args.n_epochs):
        for i, data in enumerate(dset_loaders, 0):
            if i % 10 == 0:
                print('Epoch[{}][{}/{}] Time: {}'.format(
                    epoch, i, len(dset_loaders),
                    timeSince(
                        start,
                        float(epoch * len(dset_loaders) + i) /
                        float(args.n_epochs * len(dset_loaders)))))

            # Only the captions and their lengths are used for sampling.
            (X, captions, cap_lens, class_ids, keys,
             X_wrong, caps_wrong, cap_len_wrong, cls_id_wrong,
             key_wrong) = prepare_data_real_wrong(data)

            hidden = text_encoder.init_hidden()

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                # ==================== #
                # Generate fake images #
                # ==================== #
                Z = generateZ(args)
                fake, mu, logvar = G(Z, sent_emb)

            # save images and sentence
            if (epoch + 1) % args.image_save_step == 0:
                image_path = os.path.join(args.output_dir, args.image_dir)
                if not os.path.exists(image_path):
                    os.makedirs(image_path)
                vutils.save_image(
                    fake,
                    '{}/fake_samples_epoch_{:0>5d}_{:0>5d}.png'.format(
                        image_path, epoch + 1, i),
                    normalize=True,
                    pad_value=1,
                    single_channel=True,
                    imageSize=args.imageSize)
                # Use the same (epoch + 1, i) indices as the image above so
                # each caption file pairs with its image.
                caption_path = '{}/{:0>4d}_{:0>4d}.txt'.format(
                    image_path, epoch + 1, i)
                with open(caption_path, 'a') as f:
                    for caption in captions:
                        words = []
                        for token in caption:
                            idx = token.item()
                            # index 0 ends the caption (presumably padding)
                            if idx == 0:
                                break
                            words.append(dataset.ixtoword[idx])
                        f.write(''.join(word + ' ' for word in words) + '\n')
Пример #9
0
def test(args):
    # set devices
    torch.cuda.set_device(args.gpu)
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda:{}".format(args.gpu))

    # Get data loader ##################################################
    image_transform = transforms.Compose([transforms.Resize(args.imageSize)])
    dataset = TextDataset(args.dataroot,
                          'train',
                          base_size=args.imageSize,
                          transform=image_transform,
                          input_channels=args.input_channels,
                          image_type=args.image_type)
    dset_loaders = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.batch_size,
                                               drop_last=True,
                                               shuffle=True,
                                               num_workers=args.workers)

    # ============== #
    # Define D and G #
    # ============== #
    # D = _D(args)
    # D_frame_motion = _D_frame_motion(args)
    G = _G(args)
    text_encoder = RNN_ENCODER(dataset.n_words,
                               nhidden=args.hidden_size,
                               batch_size=args.batch_size)

    if args.cuda:
        print("using cuda")
        if args.gpu_num > 1:
            device_ids = range(args.gpu, args.gpu + args.gpu_num)
            # D = nn.DataParallel(D, device_ids = device_ids)
            # D_frame_motion = nn.DataParallel(D_frame_motion, device_ids = device_ids)
            G = nn.DataParallel(G, device_ids=device_ids)
            # text_encoder = nn.DataParallel(text_encoder, device_ids = device_ids)
            # image_encoder = nn.DataParallel(image_encoder, device_ids = device_ids)
        # D.cuda()
        # D_frame_motion.cuda()
        G.cuda()
        text_encoder.cuda()
        # image_encoder.cuda()
        # criterion = criterion.cuda()
        # criterion_l1 = criterion_l1.cuda()
        # criterion_l2 = criterion_l2.cuda()

    if args.checkpoint_G != '':
        G.load_state_dict(torch.load(args.checkpoint_G, map_location='cpu'))

    # ================================================== #
    # Load text and image encoder and fix the parameters #
    # ================================================== #
    # text_encoder = RNN_ENCODER(dataset.n_words, nhidden=args.hidden_size, batch_size=args.batch_size)
    # image_encoder = CNN_ENCODER(args.hidden_size)

    if args.checkpoint_text_encoder != '':
        text_encoder.load_state_dict(torch.load(args.checkpoint_text_encoder))
    # if args.checkpoint_image_encoder != '':
    #     image_encoder.load_state_dict(torch.load(args.checkpoint_image_encoder))

    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from: %s' % args.checkpoint_text_encoder)

    text_encoder.eval()
    G.eval()

    # criterion of update
    # D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    # D_solver_frame = optim.Adam(D_frame_motion.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if args.cls:
        print("using cls")
    if args.A:
        print("using A")
    # if args.C:
    #     print("using C")
    if args.video_loss:
        print("using video discriminator")
    if args.frame_motion_loss:
        print("using frame and motion discriminator")

    # estimate time
    start = time.time()

    counter_folder = 0
    # iteration
    for epoch in range(args.n_epochs):
        for i, data in enumerate(dset_loaders, 0):
            if i % 10 == 0:
                print('Epoch[{}][{}/{}] Time: {}'.format(
                    epoch, i, len(dset_loaders),
                    timeSince(
                        start,
                        float(epoch * len(dset_loaders) + i) /
                        float(args.n_epochs * len(dset_loaders)))))

            # if (i+1)%5==0:
            #     break

            X, captions, cap_lens, class_ids, keys, \
            X_wrong, caps_wrong, cap_len_wrong, cls_id_wrong, key_wrong \
            = prepare_data_real_wrong(data)
            X = X[0]
            X_wrong = X_wrong[0]

            hidden = text_encoder.init_hidden()

            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

            # ==================== #
            # Generate fake images #
            # ==================== #
            # generate input noize Z (size: batch_size x 100)
            Z = generateZ(args)
            # Z = torch.cat([Z, embedding], 1)

            fake, mu, logvar = G(Z, sent_emb)

            # save images and sentence
            if (epoch + 1) % args.image_save_step == 0:
                fid_image_path = os.path.join(args.output_dir,
                                              args.fid_fake_foldername,
                                              "images")
                if not os.path.exists(fid_image_path):
                    os.makedirs(fid_image_path)
                counter_folder = vutils.save_image_forfinalFID(
                    fake,
                    None,
                    normalize=True,
                    pad_value=1,
                    input_channels=args.input_channels,
                    imageSize=args.imageSize,
                    fid_image_path=fid_image_path,
                    counter_folder=counter_folder)