# Example #1
# 0
def train():
    """Train the conditional WGAN-GP generator/discriminator pair on NAB.

    Everything is configured through the module-level ``opt``. ``netD``
    returns a real/fake score and class logits. Each iteration runs 5
    discriminator updates (WGAN loss + auxiliary classification loss +
    gradient penalty), then 1 generator update (adversarial + classification
    loss, plus the optional centroid, ||W||_2 and ||W_z||_21 regularisers).
    Iterations run from ``start_step`` (0, or the ``'it'`` value of the
    ``opt.resume`` checkpoint) up to 3000 inclusive.

    Files written under ``out/<exp_info>/<exp_params>/``:
      * ``log_<exp_info>.txt``: training log (appended to);
      * ``Best_model_Acc_*.tar``: whenever ``eval_fakefeat_test`` sets
        ``result.save_model`` (previous ``Best_model*`` files are removed);
      * ``Iter_<it>.tar``: every ``opt.save_interval`` iterations.
    """
    param = _param()
    dataset = LoadDataset_NAB(opt)

    data_layer = FeatDataLayer(dataset.labels_train,
                               dataset.pfc_feat_data_train, opt)
    result = Result()

    netG = _netG(dataset.text_dim, dataset.feature_dim).cuda()
    netG.apply(weights_init)
    print(netG)
    netD = _netD(dataset.train_cls_num, dataset.feature_dim).cuda()
    netD.apply(weights_init)
    print(netD)

    exp_info = 'NAB_EASY' if opt.splitmode == 'easy' else 'NAB_HARD'
    exp_params = 'Eu{}_Rls{}_RWz{}'.format(opt.CENT_LAMBDA, opt.REG_W_LAMBDA,
                                           opt.REG_Wz_LAMBDA)

    out_subdir = 'out/{:s}/{:s}'.format(exp_info, exp_params)
    # makedirs also creates the intermediate 'out' and 'out/<exp_info>' dirs.
    if not os.path.exists(out_subdir):
        os.makedirs(out_subdir)

    cprint(" The output dictionary is {}".format(out_subdir), 'red')
    log_dir = out_subdir + '/log_{:s}.txt'.format(exp_info)
    with open(log_dir, 'a') as f:
        f.write('Training Start:')
        f.write(strftime("%a, %d %b %Y %H:%M:%S +0000", gmtime()) + '\n')

    start_step = 0

    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            netG.load_state_dict(checkpoint['state_dict_G'])
            netD.load_state_dict(checkpoint['state_dict_D'])
            start_step = checkpoint['it']
            print(checkpoint['log'])
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    nets = [netG, netD]

    tr_cls_centroid = Variable(
        torch.from_numpy(dataset.tr_cls_centroid.astype('float32'))).cuda()
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(0.5, 0.9))

    # Checkpoints embed the latest log line; initialise it so a save that
    # happens before the first display iteration cannot raise NameError.
    log_text = ''

    def _to_float(t):
        # Works for both 0-dim and 1-element tensors. ``.data[0]`` raises on
        # 0-dim tensors (e.g. the result of torch.mean) in PyTorch >= 0.5.
        return float(t.data.cpu().numpy().reshape(-1)[0])

    for it in range(start_step, 3000 + 1):
        """ Discriminator """
        for _ in range(5):
            blobs = data_layer.forward()
            feat_data = blobs['data']  # image data
            labels = blobs['labels'].astype(int)  # class labels
            text_feat = np.array(
                [dataset.train_text_feature[i, :] for i in labels])
            text_feat = Variable(torch.from_numpy(
                text_feat.astype('float32'))).cuda()
            X = Variable(torch.from_numpy(feat_data)).cuda()
            # cross_entropy needs int64 targets; numpy 'int' is int32 on some
            # platforms (e.g. Windows).
            y_true = Variable(torch.from_numpy(labels.astype('int64'))).cuda()
            z = Variable(torch.randn(opt.batchsize, param.z_dim)).cuda()

            # GAN's D loss on real features
            D_real, C_real = netD(X)
            D_loss_real = torch.mean(D_real)
            C_loss_real = F.cross_entropy(C_real, y_true)
            DC_loss = -D_loss_real + C_loss_real
            DC_loss.backward()

            # GAN's D loss on generated features (G is not updated here)
            G_sample = netG(z, text_feat).detach()
            D_fake, C_fake = netD(G_sample)
            D_loss_fake = torch.mean(D_fake)
            C_loss_fake = F.cross_entropy(C_fake, y_true)
            DC_loss = D_loss_fake + C_loss_fake
            DC_loss.backward()

            # train with gradient penalty (WGAN_GP)
            grad_penalty = calc_gradient_penalty(netD, X.data, G_sample.data)
            grad_penalty.backward()

            Wasserstein_D = D_loss_real - D_loss_fake
            optimizerD.step()
            reset_grad(nets)
        """ Generator """
        for _ in range(1):
            blobs = data_layer.forward()
            feat_data = blobs['data']  # image data
            labels = blobs['labels'].astype(int)  # class labels
            text_feat = np.array(
                [dataset.train_text_feature[i, :] for i in labels])
            text_feat = Variable(torch.from_numpy(
                text_feat.astype('float32'))).cuda()

            X = Variable(torch.from_numpy(feat_data)).cuda()
            y_true = Variable(torch.from_numpy(labels.astype('int64'))).cuda()
            z = Variable(torch.randn(opt.batchsize, param.z_dim)).cuda()

            G_sample = netG(z, text_feat)
            D_fake, C_fake = netD(G_sample)
            _, C_real = netD(X)

            # GAN's G loss
            G_loss = torch.mean(D_fake)
            # Auxiliary classification loss
            C_loss = (F.cross_entropy(C_real, y_true) +
                      F.cross_entropy(C_fake, y_true)) / 2

            GC_loss = -G_loss + C_loss

            # Centroid loss: distance between the mean generated feature of
            # each class present in the batch and that class's centroid.
            # Gated on CENT_LAMBDA, the weight it is actually scaled by.
            Euclidean_loss = Variable(torch.Tensor([0.0])).cuda()
            if opt.CENT_LAMBDA != 0:
                for i in range(dataset.train_cls_num):
                    sample_idx = (y_true == i).data.nonzero()
                    if sample_idx.numel() == 0:
                        continue
                    # Flatten the (k, 1) index tensor rather than squeeze it:
                    # with k == 1, squeeze() gives a 0-dim index, so
                    # G_sample[idx, :] would be 1-D and mean(dim=0) would
                    # average over the feature axis instead of the samples.
                    G_sample_cls = G_sample[sample_idx.view(-1), :]
                    Euclidean_loss += (
                        G_sample_cls.mean(dim=0) -
                        tr_cls_centroid[i]).pow(2).sum().sqrt()
                Euclidean_loss *= 1.0 / dataset.train_cls_num * opt.CENT_LAMBDA

            # ||W||_2 regularization over every generator weight tensor
            reg_loss = Variable(torch.Tensor([0.0])).cuda()
            if opt.REG_W_LAMBDA != 0:
                for name, p in netG.named_parameters():
                    if 'weight' in name:
                        reg_loss += p.pow(2).sum()
                reg_loss.mul_(opt.REG_W_LAMBDA)

            # ||W_z||21 regularization, make W_z sparse
            reg_Wz_loss = Variable(torch.Tensor([0.0])).cuda()
            if opt.REG_Wz_LAMBDA != 0:
                Wz = netG.rdc_text.weight
                reg_Wz_loss = Wz.pow(2).sum(dim=0).sqrt().sum().mul(
                    opt.REG_Wz_LAMBDA)

            all_loss = GC_loss + Euclidean_loss + reg_loss + reg_Wz_loss
            all_loss.backward()
            optimizerG.step()
            reset_grad(nets)

        if it % opt.disp_interval == 0 and it:
            # Classification accuracy of netD on the last generator batch.
            y_np = y_true.data.cpu().numpy()
            batch = float(y_true.data.size()[0])
            acc_real = (np.argmax(C_real.data.cpu().numpy(), axis=1)
                        == y_np).sum() / batch
            acc_fake = (np.argmax(C_fake.data.cpu().numpy(), axis=1)
                        == y_np).sum() / batch

            log_text = ('Iter-{}; Was_D: {:.4}; Euc_ls: {:.4}; reg_ls: {:.4}; '
                        'Wz_ls: {:.4}; G_loss: {:.4}; D_loss_real: {:.4};'
                        ' D_loss_fake: {:.4}; rl: {:.4}%; fk: {:.4}%').format(
                            it, _to_float(Wasserstein_D),
                            _to_float(Euclidean_loss), _to_float(reg_loss),
                            _to_float(reg_Wz_loss), _to_float(G_loss),
                            _to_float(D_loss_real), _to_float(D_loss_fake),
                            acc_real * 100, acc_fake * 100)
            print(log_text)
            with open(log_dir, 'a') as f:
                f.write(log_text + '\n')

        if it % opt.evl_interval == 0 and it >= 100:
            netG.eval()
            eval_fakefeat_test(it, netG, dataset, param, result)
            if result.save_model:
                # Keep a single "best" checkpoint on disk.
                for _i in glob.glob(out_subdir + '/Best_model*'):
                    os.remove(_i)
                torch.save(
                    {
                        'it': it + 1,
                        'state_dict_G': netG.state_dict(),
                        'state_dict_D': netD.state_dict(),
                        'random_seed': opt.manualSeed,
                        'log': log_text,
                    }, out_subdir +
                    '/Best_model_Acc_{:.2f}.tar'.format(result.acc_list[-1]))
            netG.train()

        if it % opt.save_interval == 0 and it:
            torch.save(
                {
                    'it': it + 1,
                    'state_dict_G': netG.state_dict(),
                    'state_dict_D': netD.state_dict(),
                    'random_seed': opt.manualSeed,
                    'log': log_text,
                }, out_subdir + '/Iter_{:d}.tar'.format(it))
            cprint('Save model to ' + out_subdir + '/Iter_{:d}.tar'.format(it),
                   'red')
# Example #2
# 0
def _sm_divergence_term(probs, a, b):
    """Return the per-sample SM-divergence term of ``probs`` vs. a uniform q.

    For each row ``p`` of the 2-D ``probs`` this computes::

        ((sum_k (p_k + 1e-5) ** a * q_k ** (1 - a)) ** ((1 - a) / (1 - b)) - 1) / (b - 1)

    with ``q_k = 1 / K`` and ``K = probs.size(1)``. The exponents ``a`` and
    ``b`` are broadcast against the per-row sum, so the output shape follows
    ordinary broadcasting rules. The sign and any normalisation are left to
    the caller. ``probs`` must live on the GPU because ``q`` is allocated
    there.
    """
    batch, num_cls = probs.data.size(0), probs.data.size(1)
    q_shape = Variable(torch.FloatTensor(batch, num_cls)).cuda()
    q_shape.data.fill_(1.0 / num_cls)
    pow_a_b = torch.div(1 - a, 1 - b)
    alpha_term = (torch.pow(probs + 1e-5, a) *
                  torch.pow(q_shape, 1 - a)).sum(1)
    return torch.div(torch.pow(alpha_term, pow_a_b) - 1, b - 1)


def train(creative_weight=1000, model_num=1, is_val=True):
    """Train the generator/discriminator with an additional "creative" loss.

    Each iteration does three things:
      * it conditions ``netG`` on random convex mixtures (weights in
        [0.2, 0.8)) of two shuffled copies of a batch's text features;
      * it runs 5 WGAN-GP discriminator updates with an auxiliary
        classification loss;
      * it runs 1 generator update that adds the creative loss selected by
        ``model_num`` to the usual adversarial, classification and
        regularisation terms.

    Args:
        creative_weight: weight of the creative loss; stored in
            ``opt.Creative_weight``.
        model_num: which creative term to use:
            1 - log-softmax based term ("KL Divergence" in the comments);
            2 - SM divergence with learned ``a`` and ``b``;
            3 - SM divergence with ``a = b = opt.SM_Alpha`` (labelled
                "Bachatera" originally; presumably Bhattacharyya - TODO
                confirm);
            4 - Tsallis: learned ``a``, ``b = a``;
            5 - Renyi: learned ``a``, ``b = opt.SM_Beta``;
            6 - cross-entropy classifying creative samples as an extra
                class with index ``dataset.train_cls_num``.
        is_val: forwarded to the dataset loader.

    Returns:
        The ``Result`` object, with ``best_acc`` / ``best_auc`` updated from
        the periodic evaluations.

    Raises:
        ValueError: if ``model_num`` is not in 1..6.
    """
    if model_num not in (1, 2, 3, 4, 5, 6):
        # Any other value would otherwise only fail with a NameError inside
        # the training loop, after the output directory has been created.
        raise ValueError('Unsupported model_num: {}'.format(model_num))

    param = _param()
    if opt.dataset == 'CUB':
        dataset = LoadDataset(opt, main_dir, is_val)
        exp_info = 'CUB_EASY' if opt.splitmode == 'easy' else 'CUB_HARD'
    elif opt.dataset == 'NAB':
        dataset = LoadDataset_NAB(opt, main_dir, is_val)
        exp_info = 'NAB_EASY' if opt.splitmode == 'easy' else 'NAB_HARD'
    else:
        print('No Dataset with that name')
        # Configuration error: exit with a non-zero status, not "success".
        sys.exit(1)
    param.X_dim = dataset.feature_dim
    opt.Creative_weight = creative_weight

    data_layer = FeatDataLayer(dataset.labels_train,
                               dataset.pfc_feat_data_train, opt)
    result = Result()

    # Models 2, 4 and 5 learn SM exponents through ``log_SM_ab``.
    learn_sm_params = model_num in (2, 4, 5)

    # Constant (1, 1) input to ``log_SM_ab``; its output parametrises a / b.
    ones = Variable(torch.Tensor(1, 1))
    ones.data.fill_(1.0)

    netG = _netG(dataset.text_dim, dataset.feature_dim).cuda()
    netG.apply(weights_init)
    # Model 6 classifies creative samples as one extra class, so the
    # classifier head needs one more output.
    if model_num == 6:
        netD = _netD(dataset.train_cls_num + 1, dataset.feature_dim).cuda()
    else:
        netD = _netD(dataset.train_cls_num, dataset.feature_dim).cuda()
    netD.apply(weights_init)

    if learn_sm_params:
        # Model 2 reads two outputs (a and b); models 4 and 5 read only a.
        log_SM_ab = Scale(2 if model_num == 2 else 1)
        log_SM_ab = nn.DataParallel(log_SM_ab).cuda()

    exp_params = 'Model_{}_CAN{}_Eu{}_Rls{}_RWz{}_{}'.format(
        model_num, opt.Creative_weight, opt.CENT_LAMBDA, opt.REG_W_LAMBDA,
        opt.REG_Wz_LAMBDA, opt.exp_name)

    out_subdir = main_dir + 'out/cizsl-reproduce-3/{:s}/{:s}'.format(
        exp_info, exp_params)
    if not os.path.exists(out_subdir):
        os.makedirs(out_subdir)

    log_dir = out_subdir + '/log_{:s}.txt'.format(exp_info)
    with open(log_dir, 'a') as f:
        f.write('Training Start:')
        f.write(strftime("%a, %d %b %Y %H:%M:%S +0000", gmtime()) + '\n')

    start_step = 0

    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            netG.load_state_dict(checkpoint['state_dict_G'])
            netD.load_state_dict(checkpoint['state_dict_D'])
            start_step = checkpoint['it']
            print(checkpoint['log'])
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    nets = [netG, netD, log_SM_ab] if learn_sm_params else [netG, netD]

    tr_cls_centroid = Variable(
        torch.from_numpy(dataset.tr_cls_centroid.astype('float32'))).cuda()
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(0.5, 0.9))
    if learn_sm_params:
        optimizer_SM_ab = optim.Adam(log_SM_ab.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))

    # Checkpoints embed the latest log line; make sure it always exists
    # (e.g. after resuming at an iteration between display intervals).
    log_text = ''

    for it in tqdm(range(start_step, 3000 + 1)):
        # Creative Loss: condition G on mixtures of two shuffled copies of
        # the batch's text features.
        blobs = data_layer.forward()
        labels = blobs['labels'].astype(int)
        # Target class for model 6: the extra "new" class index.
        new_class_labels = Variable(
            torch.from_numpy((np.ones_like(labels) *
                              dataset.train_cls_num).astype('int64'))).cuda()
        text_feat_1 = np.array(
            [dataset.train_text_feature[i, :] for i in labels])
        text_feat_2 = np.array(
            [dataset.train_text_feature[i, :] for i in labels])
        # Shuffle both copies to guarantee different permutations.
        np.random.shuffle(text_feat_1)
        np.random.shuffle(text_feat_2)
        # One mixing weight per sample, uniform in [0.2, 0.8).
        alpha = (np.random.random(len(labels)) * (.8 - .2)) + .2

        text_feat_mean = np.multiply(alpha, text_feat_1.transpose())
        text_feat_mean += np.multiply(1. - alpha, text_feat_2.transpose())
        text_feat_mean = text_feat_mean.transpose()
        text_feat_mean = normalize(text_feat_mean, norm='l2', axis=1)
        text_feat_Creative = Variable(
            torch.from_numpy(text_feat_mean.astype('float32'))).cuda()
        z_creative = Variable(torch.randn(opt.batchsize, param.z_dim)).cuda()
        G_creative_sample = netG(z_creative, text_feat_Creative)
        """ Discriminator """
        for _ in range(5):
            blobs = data_layer.forward()
            feat_data = blobs['data']  # image data
            labels = blobs['labels'].astype(int)  # class labels

            text_feat = np.array(
                [dataset.train_text_feature[i, :] for i in labels])
            text_feat = Variable(torch.from_numpy(
                text_feat.astype('float32'))).cuda()
            X = Variable(torch.from_numpy(feat_data)).cuda()
            # cross_entropy needs int64 targets; numpy 'int' is int32 on some
            # platforms (e.g. Windows).
            y_true = Variable(torch.from_numpy(labels.astype('int64'))).cuda()
            z = Variable(torch.randn(opt.batchsize, param.z_dim)).cuda()

            # GAN's D loss on real features
            D_real, C_real = netD(X)
            D_loss_real = torch.mean(D_real)
            C_loss_real = F.cross_entropy(C_real, y_true)
            DC_loss = -D_loss_real + C_loss_real
            DC_loss.backward()

            # GAN's D loss on generated features (G is not updated here)
            G_sample = netG(z, text_feat).detach()
            D_fake, C_fake = netD(G_sample)
            D_loss_fake = torch.mean(D_fake)
            C_loss_fake = F.cross_entropy(C_fake, y_true)

            DC_loss = D_loss_fake + C_loss_fake
            DC_loss.backward()

            # train with gradient penalty (WGAN_GP)
            grad_penalty = calc_gradient_penalty(netD, X.data, G_sample.data)
            grad_penalty.backward()

            Wasserstein_D = D_loss_real - D_loss_fake
            optimizerD.step()
            reset_grad(nets)
        """ Generator """
        for _ in range(1):
            blobs = data_layer.forward()
            feat_data = blobs['data']  # image data
            labels = blobs['labels'].astype(int)  # class labels
            text_feat = np.array(
                [dataset.train_text_feature[i, :] for i in labels])
            text_feat = Variable(torch.from_numpy(
                text_feat.astype('float32'))).cuda()

            X = Variable(torch.from_numpy(feat_data)).cuda()
            y_true = Variable(torch.from_numpy(labels.astype('int64'))).cuda()
            z = Variable(torch.randn(opt.batchsize, param.z_dim)).cuda()

            G_sample = netG(z, text_feat)
            D_fake, C_fake = netD(G_sample)
            _, C_real = netD(X)

            # GAN's G loss
            G_loss = torch.mean(D_fake)
            # Auxiliary classification loss
            C_loss = (F.cross_entropy(C_real, y_true) +
                      F.cross_entropy(C_fake, y_true)) / 2

            GC_loss = -G_loss + C_loss

            # Centroid loss, gated on CENT_LAMBDA (the weight it is scaled by).
            Euclidean_loss = Variable(torch.Tensor([0.0])).cuda()
            if opt.CENT_LAMBDA != 0:
                for i in range(dataset.train_cls_num):
                    sample_idx = (y_true == i).data.nonzero()
                    if sample_idx.numel() == 0:
                        continue
                    # Flatten the (k, 1) index instead of squeezing it, so a
                    # class with a single sample still yields a 2-D slice and
                    # mean(dim=0) averages over samples, not features.
                    G_sample_cls = G_sample[sample_idx.view(-1), :]
                    Euclidean_loss += (
                        G_sample_cls.mean(dim=0) -
                        tr_cls_centroid[i]).pow(2).sum().sqrt()
                Euclidean_loss *= 1.0 / dataset.train_cls_num * opt.CENT_LAMBDA

            # ||W||_2 regularization over every generator weight tensor
            reg_loss = Variable(torch.Tensor([0.0])).cuda()
            if opt.REG_W_LAMBDA != 0:
                for name, p in netG.named_parameters():
                    if 'weight' in name:
                        reg_loss += p.pow(2).sum()
                reg_loss.mul_(opt.REG_W_LAMBDA)

            # ||W_z||21 regularization, make W_z sparse
            reg_Wz_loss = Variable(torch.Tensor([0.0])).cuda()
            if opt.REG_Wz_LAMBDA != 0:
                Wz = netG.rdc_text.weight
                reg_Wz_loss = Wz.pow(2).sum(dim=0).sqrt().sum().mul(
                    opt.REG_Wz_LAMBDA)

            # D(C| GX_fake)) + Classify GX_fake as real
            D_creative_fake, C_creative_fake = netD(G_creative_sample)
            if model_num == 1:  # KL Divergence
                G_fake_C = F.log_softmax(C_creative_fake, dim=1)
                entropy_GX_fake = (G_fake_C / G_fake_C.data.size(1)).mean()
            else:
                G_fake_C = F.softmax(C_creative_fake, dim=1)

            # Learned exponents are sigmoid outputs rescaled to (0.2, 0.8).
            if model_num == 2:  # SM Divergence, a and b learned
                SM_ab = torch.sigmoid(log_SM_ab(ones))
                SM_a = 0.2 + torch.div(SM_ab[0][0], 1.6666666666666667).cuda()
                SM_b = 0.2 + torch.div(SM_ab[0][1], 1.6666666666666667).cuda()
                entropy_GX_fake_vec = _sm_divergence_term(G_fake_C, SM_a, SM_b)
            elif model_num == 3:  # fixed a = b = opt.SM_Alpha
                SM_a = Variable(torch.FloatTensor(1, 1)).cuda()
                SM_a.data.fill_(opt.SM_Alpha)
                SM_b = Variable(torch.FloatTensor(1, 1)).cuda()
                SM_b.data.fill_(opt.SM_Alpha)
                entropy_GX_fake_vec = -_sm_divergence_term(
                    G_fake_C, SM_a, SM_b)
            elif model_num == 4:  # Tsallis Divergence: b = a
                SM_ab = torch.sigmoid(log_SM_ab(ones))
                SM_a = 0.2 + torch.div(SM_ab[0][0], 1.6666666666666667).cuda()
                entropy_GX_fake_vec = -_sm_divergence_term(
                    G_fake_C, SM_a, SM_a)
            elif model_num == 5:  # Renyi Divergence: b = opt.SM_Beta
                SM_ab = torch.sigmoid(log_SM_ab(ones))
                SM_a = 0.2 + torch.div(SM_ab[0][0], 1.6666666666666667).cuda()
                SM_b = Variable(torch.FloatTensor(1, 1)).cuda()
                SM_b.data.fill_(opt.SM_Beta)
                entropy_GX_fake_vec = -_sm_divergence_term(
                    G_fake_C, SM_a, SM_b)

            if model_num == 6:
                loss_creative = F.cross_entropy(C_creative_fake,
                                                new_class_labels)
            else:
                if model_num != 1:
                    # Min-max normalise the SM-divergence and report its mean.
                    min_e = torch.min(entropy_GX_fake_vec)
                    max_e = torch.max(entropy_GX_fake_vec)
                    # Avoid 0/0 = NaN when all samples share the same value.
                    spread = (max_e - min_e).clamp(min=1e-8)
                    entropy_GX_fake_vec = (entropy_GX_fake_vec - min_e) / spread
                    entropy_GX_fake = -entropy_GX_fake_vec.mean()
                loss_creative = -opt.Creative_weight * entropy_GX_fake

            disc_GX_fake_real = -torch.mean(D_creative_fake)
            total_loss_creative = loss_creative + disc_GX_fake_real

            all_loss = (GC_loss + Euclidean_loss + reg_loss + reg_Wz_loss +
                        total_loss_creative)
            all_loss.backward()
            if learn_sm_params:
                optimizer_SM_ab.step()
            optimizerG.step()
            reset_grad(nets)

        if it % opt.disp_interval == 0 and it:
            # Classification accuracy of netD on the last generator batch.
            y_np = y_true.data.cpu().numpy()
            batch = float(y_true.data.size()[0])
            acc_real = (np.argmax(C_real.data.cpu().numpy(), axis=1)
                        == y_np).sum() / batch
            acc_fake = (np.argmax(C_fake.data.cpu().numpy(), axis=1)
                        == y_np).sum() / batch

            log_text = 'Iter-{}; rl: {:.4}%; fk: {:.4}%'.format(
                it, acc_real * 100, acc_fake * 100)
            with open(log_dir, 'a') as f:
                f.write(log_text + '\n')

        if it % opt.evl_interval == 0 and it > opt.disp_interval:
            netG.eval()
            cur_acc = eval_fakefeat_test(it, netG, dataset, param, result)
            cur_auc = eval_fakefeat_GZSL(netG, dataset, param, out_subdir,
                                         result)

            if cur_acc > result.best_acc:
                result.best_acc = cur_acc

            if cur_auc > result.best_auc:
                result.best_auc = cur_auc
                # Save on every new best AUC. This function has no separate
                # periodic checkpoint, so skipping some iterations (the old
                # ``if it % opt.save_interval`` guard) lost those best models.
                for _i in glob.glob(out_subdir + '/Best_model*'):
                    os.remove(_i)
                torch.save(
                    {
                        'it': it + 1,
                        'state_dict_G': netG.state_dict(),
                        'state_dict_D': netD.state_dict(),
                        'random_seed': opt.manualSeed,
                        'log': log_text,
                    }, out_subdir +
                    '/Best_model_AUC_{:.3f}.tar'.format(cur_auc))

            netG.train()
    return result
# Example #3
# 0
def train():
    param = _param()
    dataset = LoadDataset_NAB(opt)
    param.X_dim = dataset.feature_dim

    data_layer = FeatDataLayer(dataset.labels_train,
                               dataset.pfc_feat_data_train,
                               dataset.seen_label_mapping, opt)
    result = Result()
    result_gzsl = Result()

    netG = _netG(dataset.text_dim, dataset.feature_dim).cuda()
    netG.apply(weights_init)
    print(netG)
    netD = _netD(dataset.train_cls_num + dataset.test_cls_num,
                 dataset.feature_dim).cuda()
    netD.apply(weights_init)
    print(netD)

    exp_info = 'NAB_EASY' if opt.splitmode == 'easy' else 'NAB_HARD'
    exp_params = 'Eu{}_Rls{}_RWz{}'.format(opt.CENT_LAMBDA, opt.REG_W_LAMBDA,
                                           opt.REG_Wz_LAMBDA)

    out_dir = 'out_' + str(opt.epsilon) + '/{:s}'.format(exp_info)
    out_subdir = 'out_' + str(opt.epsilon) + '/{:s}/{:s}'.format(
        exp_info, exp_params)

    if not os.path.exists('out_' + str(opt.epsilon)):
        os.mkdir('out_' + str(opt.epsilon))
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)
    if not os.path.exists(out_subdir):
        os.mkdir(out_subdir)

    cprint(" The output dictionary is {}".format(out_subdir), 'red')
    log_dir = out_subdir + '/log_{:s}.txt'.format(exp_info)

    with open(log_dir, 'a') as f:
        f.write('Training Start:')
        f.write(strftime("%a, %d %b %Y %H:%M:%S +0000", gmtime()) + '\n')
        f.write("Running Parameter Logs")
        f.write(runing_parameters_logs)

    start_step = 0

    if opt.splitmode != 'easy':
        epochs = 1000

    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            netG.load_state_dict(checkpoint['state_dict_G'])
            netD.load_state_dict(checkpoint['state_dict_D'])
            start_step = checkpoint['it']
            print(checkpoint['log'])
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    nets = [netG, netD]

    tr_cls_centroid = Variable(
        torch.from_numpy(dataset.tr_cls_centroid.astype('float32'))).cuda()
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(0.5, 0.9))

    for it in range(start_step, epochs):
        if it > opt.mode_change:
            train_text = Variable(
                torch.from_numpy(
                    dataset.train_text_feature.astype('float32'))).cuda()
            test_text = Variable(
                torch.from_numpy(
                    dataset.test_text_feature.astype('float32'))).cuda()
            z_train = Variable(torch.randn(dataset.train_cls_num,
                                           param.z_dim)).cuda()
            z_test = Variable(torch.randn(dataset.test_cls_num,
                                          param.z_dim)).cuda()

            _, train_text_feature = netG(z_train, train_text)
            _, test_text_feature = netG(z_test, test_text)

            dataset.semantic_similarity_check(
                opt.Knn,
                train_text_feature.data.cpu().numpy(),
                test_text_feature.data.cpu().numpy())
        """ Discriminator """
        for _ in range(5):
            blobs = data_layer.forward()
            feat_data = blobs['data']  # image data
            labels = blobs['labels'].astype(int)  # class labels
            true_labels = blobs['true_labels'].astype(int)

            text_feat = np.array(
                [dataset.train_text_feature[i, :] for i in labels])
            text_feat = Variable(torch.from_numpy(
                text_feat.astype('float32'))).cuda()
            X = Variable(torch.from_numpy(feat_data)).cuda()
            y_true = Variable(torch.from_numpy(
                true_labels.astype('int'))).cuda()

            z = Variable(torch.randn(opt.batchsize, param.z_dim)).cuda()

            # GAN's D loss
            D_real, C_real = netD(X)
            D_loss_real = torch.mean(D_real)
            C_loss_real = F.cross_entropy(C_real, y_true)
            DC_loss = -D_loss_real + C_loss_real
            DC_loss.backward()

            # GAN's D loss
            G_sample, _ = netG(z, text_feat)
            D_fake, C_fake = netD(G_sample)
            D_loss_fake = torch.mean(D_fake)
            C_loss_fake = F.cross_entropy(C_fake, y_true)
            DC_loss = D_loss_fake + C_loss_fake
            DC_loss.backward()

            # train with gradient penalty (WGAN_GP)
            grad_penalty = calc_gradient_penalty(netD, X.data, G_sample.data)
            grad_penalty.backward()

            Wasserstein_D = D_loss_real - D_loss_fake
            optimizerD.step()
            reset_grad(nets)
        """ Generator """
        for _ in range(1):
            blobs = data_layer.forward()
            feat_data = blobs['data']  # image data
            labels = blobs['labels'].astype(int)  # class labels
            true_labels = blobs['true_labels'].astype(
                int)  #True seen label class

            text_feat = np.array(
                [dataset.train_text_feature[i, :] for i in labels])
            text_feat = Variable(torch.from_numpy(
                text_feat.astype('float32'))).cuda()

            X = Variable(torch.from_numpy(feat_data)).cuda()
            y_dummy = Variable(torch.from_numpy(labels.astype('int'))).cuda()
            y_true = Variable(torch.from_numpy(
                true_labels.astype('int'))).cuda()

            z = Variable(torch.randn(opt.batchsize, param.z_dim)).cuda()

            G_sample, _ = netG(z, text_feat)
            D_fake, C_fake = netD(G_sample)
            _, C_real = netD(X)

            # GAN's G loss
            G_loss = torch.mean(D_fake)
            # Auxiliary classification loss
            C_loss = (F.cross_entropy(C_real, y_true) +
                      F.cross_entropy(C_fake, y_true)) / 2

            GC_loss = -G_loss + C_loss

            # Centroid loss
            Euclidean_loss = Variable(torch.Tensor([0.0])).cuda()
            Correlation_loss = Variable(torch.Tensor([0.0])).cuda()

            if opt.CENT_LAMBDA != 0:
                for i in range(dataset.train_cls_num):
                    sample_idx = (y_dummy == i).data.nonzero().squeeze()
                    if sample_idx.numel() == 0:
                        Euclidean_loss += 0.0
                    else:
                        G_sample_cls = G_sample[sample_idx, :]
                        if sample_idx.numel() != 1:
                            generated_mean = G_sample_cls.mean(dim=0)
                        else:
                            generated_mean = G_sample_cls

                        Euclidean_loss += (
                            generated_mean -
                            tr_cls_centroid[i]).pow(2).sum().sqrt()

                        for n in range(dataset.Neighbours):
                            Neighbor_correlation = cosine_similarity(
                                generated_mean.data.cpu().numpy().reshape(
                                    (1, dataset.feature_dim)),
                                tr_cls_centroid[dataset.idx_mat[
                                    i, n]].data.cpu().numpy().reshape(
                                        (1, dataset.feature_dim)))

                            lower_limit = dataset.semantic_similarity_seen[
                                i, n] - opt.epsilon
                            upper_limit = dataset.semantic_similarity_seen[
                                i, n] + opt.epsilon

                            lower_limit = torch.as_tensor(
                                lower_limit.astype('float'))
                            upper_limit = torch.as_tensor(
                                upper_limit.astype('float'))
                            corr = torch.as_tensor(
                                Neighbor_correlation[0][0].astype('float'))
                            margin = (torch.max(
                                corr - corr,
                                corr - upper_limit))**2 + (torch.max(
                                    corr - corr, lower_limit - corr))**2
                            Correlation_loss += margin

                Euclidean_loss *= 1.0 / dataset.train_cls_num * opt.CENT_LAMBDA
                Correlation_loss = Correlation_loss * opt.correlation_penalty

            # ||W||_2 regularization
            reg_loss = Variable(torch.Tensor([0.0])).cuda()
            if opt.REG_W_LAMBDA != 0:
                for name, p in netG.named_parameters():
                    if 'weight' in name:
                        reg_loss += p.pow(2).sum()
                reg_loss.mul_(opt.REG_W_LAMBDA)

            # ||W_z||21 regularization, make W_z sparse
            reg_Wz_loss = Variable(torch.Tensor([0.0])).cuda()
            if opt.REG_Wz_LAMBDA != 0:
                Wz = netG.rdc_text.weight
                reg_Wz_loss = Wz.pow(2).sum(dim=0).sqrt().sum().mul(
                    opt.REG_Wz_LAMBDA)

            all_loss = GC_loss + Euclidean_loss + reg_loss + reg_Wz_loss + Correlation_loss
            all_loss.backward()
            optimizerG.step()
            reset_grad(nets)

        if it > opt.unseen_start:
            for _ in range(1):
                # Zero shot Discriminator is training
                zero_shot_labels = np.random.randint(
                    dataset.test_cls_num,
                    size=opt.zeroshotbatchsize).astype(int)
                zero_shot_true_labels = np.array([
                    dataset.unseen_label_mapping[i] for i in zero_shot_labels
                ])
                zero_text_feat = np.array([
                    dataset.test_text_feature[i, :] for i in zero_shot_labels
                ])

                zero_text_feat = Variable(
                    torch.from_numpy(zero_text_feat.astype('float32'))).cuda()
                zero_y_true = Variable(
                    torch.from_numpy(
                        zero_shot_true_labels.astype('int'))).cuda()
                z = Variable(torch.randn(opt.zeroshotbatchsize,
                                         param.z_dim)).cuda()

                # GAN's D loss
                G_sample_zero, _ = netG(z, zero_text_feat)
                _, C_fake_zero = netD(G_sample_zero)
                C_loss_fake_zero = F.cross_entropy(C_fake_zero, zero_y_true)
                C_loss_fake_zero.backward()

                optimizerD.step()
                reset_grad(nets)

            for _ in range(1):
                # Zero shot Generator is training
                zero_shot_labels = np.random.randint(
                    dataset.test_cls_num,
                    size=opt.zeroshotbatchsize).astype(int)
                zero_shot_true_labels = np.array([
                    dataset.unseen_label_mapping[i] for i in zero_shot_labels
                ])
                zero_text_feat = np.array([
                    dataset.test_text_feature[i, :] for i in zero_shot_labels
                ])

                zero_text_feat = Variable(
                    torch.from_numpy(zero_text_feat.astype('float32'))).cuda()
                zero_y_true = Variable(
                    torch.from_numpy(
                        zero_shot_true_labels.astype('int'))).cuda()
                y_dummy_zero = Variable(
                    torch.from_numpy(zero_shot_labels.astype('int'))).cuda()
                z = Variable(torch.randn(opt.zeroshotbatchsize,
                                         param.z_dim)).cuda()

                # GAN's D loss
                G_sample_zero, _ = netG(z, zero_text_feat)
                _, C_fake_zero = netD(G_sample_zero)
                C_loss_fake_zero = F.cross_entropy(C_fake_zero, zero_y_true)

                Correlation_loss_zero = Variable(torch.Tensor([0.0])).cuda()

                if opt.CENT_LAMBDA != 0:
                    for i in range(dataset.test_cls_num):
                        sample_idx = (
                            y_dummy_zero == i).data.nonzero().squeeze()
                        if sample_idx.numel() != 0:
                            G_sample_cls = G_sample_zero[sample_idx, :]

                            if sample_idx.numel() != 1:
                                generated_mean = G_sample_cls.mean(dim=0)
                            else:
                                generated_mean = G_sample_cls

                            for n in range(dataset.Neighbours):
                                Neighbor_correlation = cosine_similarity(
                                    generated_mean.data.cpu().numpy().reshape(
                                        (1, dataset.feature_dim)),
                                    tr_cls_centroid[dataset.unseen_idx_mat[
                                        i, n]].data.cpu().numpy().reshape(
                                            (1, dataset.feature_dim)))

                                lower_limit = dataset.semantic_similarity_unseen[
                                    i, n] - opt.epsilon
                                upper_limit = dataset.semantic_similarity_unseen[
                                    i, n] + opt.epsilon

                                lower_limit = torch.as_tensor(
                                    lower_limit.astype('float'))
                                upper_limit = torch.as_tensor(
                                    upper_limit.astype('float'))
                                corr = torch.as_tensor(
                                    Neighbor_correlation[0][0].astype('float'))

                                margin = (torch.max(
                                    corr - corr,
                                    corr - upper_limit))**2 + (torch.max(
                                        corr - corr, lower_limit - corr))**2

                                Correlation_loss_zero += margin

                    Correlation_loss_zero = Correlation_loss_zero * opt.correlation_penalty

                # ||W||_2 regularization
                reg_loss_zero = Variable(torch.Tensor([0.0])).cuda()
                if opt.REG_W_LAMBDA != 0:
                    for name, p in netG.named_parameters():
                        if 'weight' in name:
                            reg_loss_zero += p.pow(2).sum()
                    reg_loss_zero.mul_(opt.REG_W_LAMBDA)

                # ||W_z||21 regularization, make W_z sparse
                reg_Wz_loss_zero = Variable(torch.Tensor([0.0])).cuda()
                if opt.REG_Wz_LAMBDA != 0:
                    Wz = netG.rdc_text.weight
                    reg_Wz_loss_zero = Wz.pow(2).sum(dim=0).sqrt().sum().mul(
                        opt.REG_Wz_LAMBDA)

                all_loss = C_loss_fake_zero + reg_loss_zero + reg_Wz_loss_zero + Correlation_loss_zero
                all_loss.backward()
                optimizerG.step()
                reset_grad(nets)

        if it % opt.disp_interval == 0 and it:
            acc_real = (np.argmax(C_real.data.cpu().numpy(), axis=1)
                        == y_true.data.cpu().numpy()).sum() / float(
                            y_true.data.size()[0])
            acc_fake = (np.argmax(C_fake.data.cpu().numpy(), axis=1)
                        == y_true.data.cpu().numpy()).sum() / float(
                            y_true.data.size()[0])

            log_text = 'Iter-{}; Was_D: {:.4}; Euc_ls: {:.4}; reg_ls: {:.4}; Wz_ls: {:.4}; G_loss: {:.4}; Correlation_loss : {:.4} ; D_loss_real: {:.4};' \
                       ' D_loss_fake: {:.4}; rl: {:.4}%; fk: {:.4}%'.format(it, Wasserstein_D.item(),  Euclidean_loss.item(), reg_loss.item(),reg_Wz_loss.item(),
                                G_loss.item(), Correlation_loss.item() , D_loss_real.item(), D_loss_fake.item(), acc_real * 100, acc_fake * 100)
            log_text1 = ""

            if it > opt.unseen_start:
                acc_fake_zero = (np.argmax(
                    C_fake_zero.data.cpu().numpy(),
                    axis=1) == zero_y_true.data.cpu().numpy()).sum() / float(
                        zero_y_true.data.size()[0])

                log_text1 = 'Zero_Shot_Iter-{}; Correlation_loss : {:.4}; fk: {:.4}%'.format(
                    it, Correlation_loss_zero.item(), acc_fake_zero * 100)

            print(log_text)
            print(log_text1)
            with open(log_dir, 'a') as f:
                f.write(log_text + '\n')
                f.write(log_text1 + '\n')

        if it % opt.evl_interval == 0 and it >= 20:
            netG.eval()
            eval_fakefeat_test(it, netG, netD, dataset, param, result)
            eval_fakefeat_GZSL(it, netG, dataset, param, result_gzsl)
            if result.save_model:
                files2remove = glob.glob(out_subdir + '/Best_model*')
                for _i in files2remove:
                    os.remove(_i)
                torch.save(
                    {
                        'it': it + 1,
                        'state_dict_G': netG.state_dict(),
                        'state_dict_D': netD.state_dict(),
                        'random_seed': opt.manualSeed,
                        'log': log_text,
                        'Zero Shot Acc': result.acc_list[-1],
                        'Generalized Zero Shot Acc': result_gzsl.acc_list[-1]
                    }, out_subdir + '/Best_model_Acc_' +
                    str(result.acc_list[-1]) + '_AUC_' +
                    str(result_gzsl.acc_list[-1]) + '_' + '.tar')
            netG.train()

        if it % opt.save_interval == 0 and it:
            torch.save(
                {
                    'it': it + 1,
                    'state_dict_G': netG.state_dict(),
                    'state_dict_D': netD.state_dict(),
                    'random_seed': opt.manualSeed,
                    'log': log_text,
                    'Zero Shot Acc': result.acc_list[-1],
                    'Generalized Zero Shot Acc': result_gzsl.acc_list[-1]
                }, out_subdir + '/Iter_{:d}.tar'.format(it))
            cprint('Save model to ' + out_subdir + '/Iter_{:d}.tar'.format(it),
                   'red')