# Input image dimensions (MNIST-style: height x width).
INPUT_H = 28
INPUT_W = 28
# Samples per mini-batch.
BATCH_SIZE = 100
# Size of the generator's latent noise vector.
LATENT_DIM = 50
# Generator optimizer: learning rate and update steps per training cycle.
GEN_LEARNING_RATE = 1e-4
GEN_STEPS_PER_CYCLE = 1
# Discriminator optimizer: note the 10x smaller learning rate — presumably to
# keep the discriminator from overpowering the generator; TODO confirm intent.
DISC_LEARNING_RATE = 1e-5
DISC_STEPS_PER_CYCLE = 1


def get_fake_labels(num):
    """Build float one-hot conditioning labels for `num` generated samples.

    Class ids are assigned round-robin (0, 1, ..., N_CLASSES-1, 0, ...) so
    every class is represented as evenly as possible.
    """
    class_ids = [idx % N_CLASSES for idx in range(num)]
    fake_classes = torch.Tensor(class_ids).to(device)
    onehot = label_to_onehot(fake_classes, N_CLASSES)
    return onehot.float()


# One-hot encode class labels for conditioning (cast to float so they can be
# concatenated with float image tensors).
C_train = label_to_onehot(Y_train, N_CLASSES).float()
C_test = label_to_onehot(Y_test, N_CLASSES).float()
""" Declare generator and discriminator """
gen = Generator(LATENT_DIM, N_CLASSES, INPUT_H, INPUT_W)
disc = Discriminator(INPUT_H, INPUT_W, N_CLASSES)
""" Declare data iterators """
# Datasets pair each image with its one-hot label.
train_dataset = TensorDataset(X_train, C_train)
test_dataset = TensorDataset(X_test, C_test)

# NOTE(review): no shuffle=True — batches are traversed in the same fixed
# order every epoch; confirm this is intended for GAN training.
train_iter = DataLoader(train_dataset, batch_size=BATCH_SIZE)
test_iter = DataLoader(test_dataset, batch_size=BATCH_SIZE)
""" Optimizers """
# Separate Adam optimizers; the discriminator uses the smaller learning rate.
gen_opt = torch.optim.Adam(gen.parameters(), lr=GEN_LEARNING_RATE)
disc_opt = torch.optim.Adam(disc.parameters(), lr=DISC_LEARNING_RATE)
""" Train """
# Per-step generator losses accumulated during training.
gen_losses = []
# NOTE(review): duplicate definition — an identical get_fake_labels is already
# defined earlier in this file, so this rebinding is a no-op; one of the two
# copies should be removed.
def get_fake_labels(num):
    """Return float one-hot labels for `num` samples, classes assigned round-robin."""
    fake_classes = torch.Tensor([i % N_CLASSES for i in range(num)]).to(device)
    return label_to_onehot(fake_classes, N_CLASSES).float()
# NOTE(review): "Ejemplo n.º 3" / "0" is residue from a scraped code listing
# (Spanish for "Example no. 3"); the snippets above and below this marker are
# separate examples. Commented out so the file parses as valid Python.
# Run everything on CPU (no CUDA device is requested anywhere in this script).
device = torch.device("cpu")

# Seed NumPy and PyTorch and pin cuDNN to deterministic kernels so runs are
# reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Create the output directory if needed. exist_ok=True removes the
# check-then-create race (TOCTOU): the original `if not exists: makedirs`
# could raise FileExistsError if another process created the directory
# between the check and the call. The existence check is kept only to
# preserve the informational print.
if not os.path.exists(OUTPUT_DIR):
    print("Creating directory {}".format(OUTPUT_DIR))
os.makedirs(OUTPUT_DIR, exist_ok=True)

# I. Get data

# Load MNIST train/test splits from the local "data" directory.
(X_train, Y_train), (X_test, Y_test) = load_MNIST("data")

# One-hot encode labels for the conditional model. NOTE(review): unlike the
# earlier assignments in this file, these are not cast with .float() — confirm
# which dtype the downstream code expects.
C_train = label_to_onehot(Y_train, N_CLASSES)
C_test = label_to_onehot(Y_test, N_CLASSES)

# II. Get trained model
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoint
# files from a trusted source.
dec = torch.load(TRAINED_MODEL_PATH)
csf = torch.load(TRAINED_CSF_PATH)

# III. Get outputs from trained model
# NOTE(review): the block below is disabled sanity-check code that samples the
# latent prior and builds a class grid; delete it or put it behind a flag once
# it is confirmed unneeded.
# n_rows = N_CLASSES
# n_cols = 10
# N = n_rows * n_cols

# prior = torch.distributions.Normal(0, 1)
# z = prior.sample((N, LATENT_DIM)).type(torch.float64)
# Y_gen = torch.tensor([i for i in range(n_cols) for j in range(n_rows)]).reshape(N, 1)