示例#1
0
def main():
    """Train the VAE on the training set, then train the Estimator on the
    VAE's latent codes, saving both state dicts under ./saved/."""
    train_loader, test_loader = make_loader()
    print("train VAE")
    vae = model.VAE(6, 2)
    vae_optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)
    losses = []
    for i in range(3):
        for data in train_loader:
            vae.zero_grad()
            loss = vae.loss(data['data'])
            loss.backward()
            vae_optimizer.step()
            losses.append(loss.cpu().detach().numpy())
        # NOTE: `losses` accumulates across epochs, so this prints a running
        # average over all iterations so far, not a per-epoch average.
        print("EPOCH: {} loss: {}".format(i, np.average(losses)))
    torch.save(vae.state_dict(), './saved/vae')

    print("train Estimator")
    estimator = model.Estimator()
    # BUG FIX: the optimizer must update the estimator's parameters, not the
    # (already trained) VAE's — previously the estimator never learned.
    estimator_optimizer = torch.optim.Adam(estimator.parameters(), lr=0.001)
    losses = []
    for i in range(3):
        for data in train_loader:
            estimator.zero_grad()
            _, z = vae(data['data'])
            loss = estimator.loss(z, data['treat'], data['outcome'])
            loss.backward()
            estimator_optimizer.step()
            losses.append(loss.cpu().detach().numpy())
        print("EPOCH: {} loss: {}".format(i, np.average(losses)))
    torch.save(estimator.state_dict(), './saved/estimator')
示例#2
0
def set_model(args):
    """Build the VAE, report its trainable-parameter count, optionally
    warm-start model and optimizer from a checkpoint, and return
    (model, optimizer, scheduler)."""
    mlp = model.VAE(args.in_features, args.latent_size)

    n_trainable = sum(p.numel() for p in mlp.parameters() if p.requires_grad)
    print("Number of traninable parameter: {0}".format(n_trainable))

    optimizer, scheduler = configure_optimizers(mlp)

    if args.warm_start is not None:
        # Restore model and optimizer state from the saved checkpoint.
        ckpt_file = args.locations['model_loc'] + '/' + args.warm_start
        checkpoint = torch.load(ckpt_file, map_location=args.device)
        mlp.load_state_dict(checkpoint['model_state_dict'])
        mlp.to(args.device)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        loss = checkpoint['training_loss']
        epoch = checkpoint['epoch']
        # Tag the output name so the continued run doesn't overwrite the original.
        args.model_name = args.model_name.replace('.tar',
                                                  '_{0}.tar'.format("cont"))
    else:
        mlp.to(args.device)

    print("Model's state_dict:")
    for param_tensor in mlp.state_dict():
        print(param_tensor, "\t", mlp.state_dict()[param_tensor].size())

    return mlp, optimizer, scheduler
示例#3
0
def load_model(saved_vae, stored_info, device, cache_path=str(Path('../tmp')), seed=None):
    """Load the cached dataset/config pickle and a trained VAE checkpoint.

    Parameters:
        saved_vae: path to the saved VAE state_dict file.
        stored_info: name (or path; basename used) of the pickled cache file.
        device: torch device the VAE is moved to.
        cache_path: directory holding the cache file.
        seed: optional RNG seed; defaults to a time-derived value.

    Returns:
        (vae, dataset, z_size, random_state)

    Raises:
        FileNotFoundError: if `saved_vae` does not exist.  (BUG FIX: the
        original fell through the `if os.path.exists(...)` branch and raised
        a confusing NameError on the final return.)
    """
    stored_info = stored_info.split(os.sep)[-1]
    cache_file = os.path.join(cache_path, stored_info)

    start_load = time.time()
    print(f"Fetching cached info at {cache_file}")
    with open(cache_file, "rb") as f:
        dataset, z_size, condition_size, condition_on, decoder_hidden_size, encoder_hidden_size, n_encoder_layers = pickle.load(f)
    end_load = time.time()
    print(f"Cache {cache_file} loaded (load time: {end_load - start_load:.2f}s)")

    # Guard clause: fail loudly if the checkpoint is missing.
    if not os.path.exists(saved_vae):
        raise FileNotFoundError(f"Saved model not found: {saved_vae}")

    print(f"Found saved model {saved_vae}")
    start_load_model = time.time()

    e = model.EncoderRNN(dataset.input_side.n_words, encoder_hidden_size, z_size, n_encoder_layers, bidirectional=True)
    d = model.DecoderRNN(z_size, dataset.trn_split.n_conditions, condition_size, decoder_hidden_size, dataset.input_side.n_words, 1, word_dropout=0)
    vae = model.VAE(e, d).to(device)
    # map_location keeps GPU-trained weights loadable on CPU-only hosts.
    vae.load_state_dict(torch.load(saved_vae, map_location=lambda storage, loc: storage))
    vae.eval()
    print(f"Trained for {vae.steps_seen} steps (load time: {time.time() - start_load_model:.2f}s)")

    print("Setting new random seed")
    if seed is None:
        # TODO: torch.manual_seed(1999) in model.py is affecting this
        new_seed = int(time.time())
        new_seed = abs(new_seed) % 4294967295  # must be between 0 and 4294967295
    else:
        new_seed = seed
    torch.manual_seed(new_seed)

    random_state = np.random.RandomState(new_seed)
    #random_state.shuffle(dataset.trn_pairs)

    return vae, dataset, z_size, random_state
示例#4
0
def set_model(args):
    """Build the VAE, pick the loss function from args.loss, configure the
    optimizer/scheduler, and optionally warm-start from a checkpoint.

    Returns:
        (mlp, loss_function, optimizer, scheduler)

    Raises:
        ValueError: if args.loss is not one of "mae", "mse", "mink", "huber".
    """
    mlp = model.VAE(args.in_features)
    pytorch_total_params = sum(p.numel() for p in mlp.parameters()
                               if p.requires_grad)
    print("Number of traninable parameter: {0}".format(pytorch_total_params))

    # BUG FIX: this dispatch had been commented out, but `loss_function` is
    # still returned below, so the function always raised NameError.
    if args.loss == "mae":
        loss_function = torch.nn.functional.l1_loss
    elif args.loss == "mse":
        loss_function = torch.nn.functional.mse_loss
    elif args.loss == "mink":
        loss_function = minkowski_error  # presumably defined in this module — verify
    elif args.loss == "huber":
        loss_function = torch.nn.functional.smooth_l1_loss
    else:
        raise ValueError("Unknown loss: {0}".format(args.loss))

    optimizer, scheduler = configure_optimizers(mlp)

    if args.warm_start:
        # Load the saved model and optimizer state.
        checkpoint = torch.load(args.locations['model_loc'] + '/' +
                                args.model_name,
                                map_location=args.device)
        mlp.load_state_dict(checkpoint['model_state_dict'])
        mlp.to(args.device)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        loss = checkpoint['loss']    # NOTE: read but not returned
        epoch = checkpoint['epoch']  # NOTE: read but not returned
        # Tag the output name so the continued run doesn't overwrite the original.
        args.model_name = args.model_name.replace(
            '.tar', '_{0}.tar'.format(args.region))
    else:
        mlp.to(args.device)
    print("Model's state_dict:")
    for param_tensor in mlp.state_dict():
        print(param_tensor, "\t", mlp.state_dict()[param_tensor].size())

    return mlp, loss_function, optimizer, scheduler
示例#5
0
def train(config):
    """Train the source-separation VAE with checkpointing, TensorBoard
    logging, early stopping, and a post-training loss plot.

    config: nested dict; all training settings live under config["train"].
    """
    # prepare dataset
    dataset = data_generator.Dataset_URMP()
    dataset.get_wav('train')
    # dataset.get_label('train')
    dataset.load_songs('train')
    train_generator = dataset.batch_generator('train')
    dataset.get_wav('val')
    # dataset.get_label('train')
    dataset.load_songs('val')
    val_generator = dataset.batch_generator('val')

    vae = model.VAE().build_models()
    # path for saving model weights
    ckpt_path = config["train"]["path"]

    # callback_1 to save best model
    # BUG FIX: the Keras keyword is `save_weights_only`; `save_weight_only`
    # is rejected as an unknown argument.
    file_name1 = 'vn_tpt__fl_ckp_{epoch}.h5'
    file_path1 = os.path.join(ckpt_path, file_name1)
    callbacks_1 = tfc.ModelCheckpoint(filepath=file_path1,
                                      save_weights_only=True)

    # callback_2 to visualize loss during training
    file_path2 = os.path.join(ckpt_path, 'log')
    callbacks_2 = tfc.TensorBoard(
        log_dir=file_path2,
        histogram_freq=0,  # How often to log histogram visualizations
        embeddings_freq=0,  # How often to log embedding visualizations
        update_freq='epoch'  # How often to write logs (default: once per epoch)
    )

    # callbacks_3 to save weights after every epoch
    callbacks_3 = SaveWeights(ckpt_path, '4_instr', epochs=5)

    # callbacks_4: early stopping
    # BUG FIX: these settings were indexed off bare list literals
    # (["train"]["early_stopping_epoch"] is a TypeError at runtime); they
    # must be read from `config`.
    callbacks_4 = tfc.EarlyStopping(
        patience=config["train"]["early_stopping_epoch"],
        verbose=1,
        monitor='loss')
    # callbacks_5 to plot loss after training
    history = LossHistory()

    GPU_Memory = False
    # allocate a subset of the available memory
    if GPU_Memory:
        gpus = tf.config.experimental.list_physical_devices('GPU')
        tf.config.experimental.set_memory_growth(gpus[0], True)

    print('training starts......')
    vae.fit_generator(train_generator,
                      steps_per_epoch=config["train"]["num_steps_train"],
                      epochs=config["train"]["num_epochs"],
                      verbose=config["train"]["verbosity"],
                      callbacks=[callbacks_3, callbacks_4, history],
                      validation_data=val_generator,
                      validation_steps=config["train"]["num_steps_val"])

    history.loss_plot('epoch')
示例#6
0
def predict_hanning(mix_path, model_weight_path, save_path):
    """Separate a mixture into three sources with half-overlapping windowed
    model predictions.

    mix_path: path of mix music for predict.
    model_weight_path: weight of model ready for loading.
    save_path: list of three output paths, in order [vn, tpt, fl].
    """
    sep_input_shape = (1, 16384, 1)
    print("start predicting......")
    mix_sequence = load(mix_path)
    assert (len(mix_sequence.shape) == 2)

    # Preallocate source predictions (same shape as input mixture);
    # one zero buffer per output source.
    sample_length = mix_sequence.shape[0]
    vn_pre = np.zeros(mix_sequence.shape, np.float32)   # prediction for one instrument
    tpt_pre = np.zeros(mix_sequence.shape, np.float32)  # prediction for one instrument
    fl_pre = np.zeros(mix_sequence.shape, np.float32)   # prediction for one instrument

    input_length = sep_input_shape[1]
    # load model
    vae = model.VAE().build_models()
    vae.load_weights(model_weight_path)

    # Slide a half-overlapping window across the whole signal.
    for source_pos in range(0, sample_length, int(input_length / 2)):
        # If last segment smaller than input_length, take the very end instead.
        if source_pos + input_length > sample_length:
            source_pos = sample_length - input_length

        mix_part = mix_sequence[source_pos:source_pos + input_length, :]
        # NOTE(review): despite the function name, a Hamming (not Hanning)
        # window is applied here — confirm which was intended.
        mix_part = np.expand_dims((np.hamming(M=input_length) * np.squeeze(mix_part, axis=-1)), axis=-1)
        # Match the training input shape: add a batch dimension of 1.
        mix_part = np.expand_dims(mix_part, axis=0)

        predict_source = np.squeeze(vae.predict(mix_part, batch_size=1), axis=0)

        # BUG FIX: all three buffers previously accumulated column 0 of the
        # prediction, making vn/tpt/fl identical; each source gets its own column.
        vn_pre[source_pos:source_pos + input_length, 0] += predict_source[:, 0]
        tpt_pre[source_pos:source_pos + input_length, 0] += predict_source[:, 1]
        fl_pre[source_pos:source_pos + input_length, 0] += predict_source[:, 2]

    save_wav(vn_pre, save_path[0])
    # print("finish fl predict......")
    save_wav(tpt_pre, save_path[1])
    save_wav(fl_pre, save_path[2])
示例#7
0
def train(expr_in,
          vae_lr=1e-4,
          epochs=500,
          info_step=10,
          batch_size=50,
          latent_dim=2,
          f="nb",
          log=True,
          scale=True):
    """Fit a VAE to an expression matrix.

    expr_in: (samples, genes) expression matrix; negatives are clamped to 0.
    vae_lr: Adam learning rate.  epochs/info_step: training length and print
    interval.  f: distribution family passed to fitting.fit and the model.
    log/scale: optional log2 transform and per-row max scaling.

    Returns the trained VAE model.
    """
    # Preprocessing: clamp negatives to zero.
    expr_in[expr_in < 0] = 0.0

    if log:
        expr_in = np.log2(expr_in + 1)
    if scale:
        for i in range(expr_in.shape[0]):
            row_max = np.max(expr_in[i, :])
            # BUG FIX: guard all-zero rows, which previously produced NaNs
            # via division by zero.
            if row_max > 0:
                expr_in[i, :] = expr_in[i, :] / row_max

    # Number of data samples
    n_sam = expr_in.shape[0]
    # Dimension of input data
    in_dim = expr_in.shape[1]
    # Build VAE model and its optimizer
    lmd = fitting.fit(expr_in, f)
    model_vae = model.VAE(in_dim=in_dim, latent_dim=latent_dim, f=f, lmd=lmd)
    optimizer_vae = tf.keras.optimizers.Adam(vae_lr)

    # Build the input pipeline once: `shuffle` reshuffles on every iteration
    # by default, so rebuilding the Dataset each epoch was redundant work.
    vae_train_set = tf.data.Dataset.from_tensor_slices(expr_in).shuffle(
        n_sam).batch(batch_size)

    # Training
    for epoch in range(1, epochs + 1):
        # Batch training
        for vae_batch in vae_train_set:
            # Update VAE model
            rec_loss, kl_loss, rank_loss = update_model(
                model_vae, vae_batch, optimizer_vae, losses.vae_loss)
        # Print training info
        if epoch % info_step == 0:
            print("Epoch", epoch, " rec_loss: ", rec_loss.numpy(),
                  " kl_loss: ", kl_loss.numpy(), " rank_loss: ",
                  rank_loss.numpy())

    return model_vae
示例#8
0
def main():
    """Load a trained VAE, reconstruct one input image five times, and plot
    the original next to the reconstructions in a 3x3 grid."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", type=str, help='input image')
    parser.add_argument("-w", type=str, default='', help='Model weights')
    args = parser.parse_args()

    assert isinstance(args.i, str)
    assert isinstance(args.w, str)

    # Restore the model from the checkpoint dump.
    dump = torch.load(args.w)
    vae = model.VAE(dump['input_shape'], dump['z_dim']).cuda()
    vae.load_state_dict(dump['state_dict'])
    vae.eval()

    # Read the image, scale to [0, 1], and move channels first.
    img = np.asarray(Image.open(args.i).resize((112, 128))) / 255
    # img = np.asarray(Image.open(args.i)) / 255
    img = np.transpose(img, [2, 0, 1])
    img_v = torch.tensor(img, dtype=torch.float32).unsqueeze(0).cuda()
    # Run the same image through the stochastic VAE five times at once.
    img_v = torch.cat((img_v, img_v, img_v, img_v, img_v), 0)
    _, __, output_v = vae.forward(img_v)
    out_img = output_v.detach().squeeze(0).cpu().numpy()

    # Plot the input followed by the five reconstructions.
    fig = plt.figure()
    panels = [np.transpose(img, [1, 2, 0])]
    panels += [np.transpose(out_img[k], [1, 2, 0]) for k in range(5)]
    for pos, panel in enumerate(panels, start=1):
        plt.subplot(3, 3, pos, xticks=[], yticks=[])
        plt.imshow(panel)
    plt.show()

    print(out_img[0] - out_img[1])
def get_pretrained_vae(bars, pianoroll=False, transpose=False, verbose=False):
    """Build a VAE in training mode and load only the pretrained second
    LSTM layer and fc_4 weights from the final checkpoint."""
    vae = model.VAE(bars, pianoroll)
    vae.train()

    location = data.path_to_root + "Models/Final"
    stride = 1 if transpose else 3
    dset_name = data.get_dataset_name("train", bars, stride, pianoroll, transpose)
    if pianoroll:
        dset_name = "pianoroll_" + dset_name

    checkpoint = get_checkpoint(dset_name, location)
    if verbose:
        print("loaded Checkpoint")

    full_state = checkpoint["last_" + 'vae_state_dict']

    # Keep only the lstm_l2 and fc_4 parameter tensors.
    pretrained_l2 = {
        name: tensor
        for name, tensor in full_state.items()
        if name[0:7] == "lstm_l2" or name[0:4] == "fc_4"
    }

    # strict=False: the rest of the network keeps its fresh initialization.
    vae.load_state_dict(pretrained_l2, strict=False)
    return vae
def get_trained_vae(bars, pianoroll=False, transpose=False, verbose=False):
    """Build a VAE in eval mode, load the full trained weights from the
    final checkpoint, and move the model to the best available device."""
    vae = model.VAE(bars, pianoroll)
    vae.eval()

    stride = 3
    if transpose:
        stride = 1
    location = data.path_to_root + "Models/Final"
    dset_name = data.get_dataset_name("train", bars, stride, pianoroll, transpose)

    checkpoint = get_checkpoint(dset_name, location)
    if verbose:
        print("loaded Checkpoint " + dset_name)

    vae.load_state_dict(checkpoint["last_" + 'vae_state_dict'])

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    vae.to(device)
    if verbose:
        print("using device " + str(device) + "\n")

    return vae
示例#11
0
def train_network(train_loader, test_loader):
    """Train the VAE, pushing reconstruction/sample visualizations to visdom
    whenever a new best test loss is reached."""
    model = m.VAE().cuda()
    optimizer = optim.Adam(model.parameters(), lr=5e-4)
    num_epochs = 10000
    vis = visdom.Visdom('http://192.168.0.14')

    best_test_loss = float("inf")
    for i in range(num_epochs):
        print("Epoch: ", (i + 1))
        start = time.time()
        train_loss = train_step(train_loader, model, optimizer)
        end = time.time()
        print("Train Time: ", end - start)
        print("Train Loss: ", train_loss)
        #test_loss = test_step(test_loader, model)
        # NOTE(review): test evaluation is disabled; with test_loss pinned to
        # 0 the "best model" branch runs every epoch — confirm intended.
        test_loss = 0.
        print("Test Loss: ", test_loss)
        if test_loss <= best_test_loss:
            print("SAVING NEW MODEL")
            best_test_loss = test_loss
            save_model(model)
            test_samples, reconstructions = get_examples(test_loader, model)
            visualize_examples(vis, test_samples, reconstructions,
                               'test_samples')
            train_samples, reconstructions = get_examples(train_loader, model)
            visualize_examples(vis, train_samples, reconstructions,
                               'train_samples')
            num_samples = 5
            # BUG FIX: the inner loop previously reused `i`, clobbering the
            # epoch counter for the rest of this epoch's body.
            sample_labels = [random.randint(0, 9) for _ in range(num_samples)]
            t_sample_labels = densify_labels(sample_labels)
            t_sample_labels = Variable(t_sample_labels).cuda()
            samples = model.generate(num_samples, t_sample_labels)
            visualize_samples(vis, samples, sample_labels)
        print("Best Loss: ", best_test_loss)
示例#12
0
def train(**kwargs):
    """Train a VAE on MNIST, live-previewing reconstructions with matplotlib
    and saving a reconstruction grid once per epoch.

    Keyword arguments override the corresponding attributes of `Config`.
    """
    cfg = Config()
    for k, v in kwargs.items():
        setattr(cfg, k, v)

    dataset = datasets.MNIST(root=cfg.train_path,
                             train=True,
                             transform=transforms.ToTensor(),
                             download=cfg.download)

    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=cfg.batch_size,
                                              shuffle=cfg.shuffle,
                                              num_workers=cfg.num_workers)
    vae = model.VAE(cfg)
    if cfg.use_gpu:
        vae.cuda()

    optimizer = torch.optim.Adam(vae.parameters(), lr=cfg.lr)
    # Fix one batch so per-epoch reconstructions are comparable.
    data_iter = iter(data_loader)
    fixed_x, _ = next(data_iter)
    torchvision.utils.save_image(fixed_x.cpu(), './data/real_images.png')
    fixed_x = Variable(fixed_x.view(fixed_x.size(0), -1))
    if cfg.use_gpu:
        fixed_x = fixed_x.cuda()

    plt.ion()
    for epoch in range(cfg.epoch):
        for i, (images, _) in enumerate(data_loader):

            images = Variable(images.view(images.size(0), -1))
            if cfg.use_gpu:
                images = images.cuda()
            out, mean, log_var = vae(images)
            # FIX: `size_average=False` is deprecated (removed in current
            # PyTorch); `reduction='sum'` is the exact equivalent.
            reconst_loss = F.binary_cross_entropy(out,
                                                  images,
                                                  reduction='sum')
            # Closed-form KL divergence between N(mean, var) and N(0, I).
            kl_divergence = torch.sum(
                0.5 * (mean**2 + torch.exp(log_var) - log_var - 1))

            total_loss = reconst_loss + kl_divergence
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            if i % 100 == 0:
                # Live preview: input digit vs. current reconstruction.
                plt.cla()
                plt.subplot(1, 2, 1)
                plt.imshow(images.data[0].view(28, 28).cpu().numpy(),
                           cmap="gray")

                plt.subplot(1, 2, 2)
                plt.imshow(out.data[0].view(28, 28).cpu().numpy(), cmap="gray")
                plt.draw()
                plt.pause(0.01)
        reconst_images, _, _ = vae(fixed_x)
        reconst_images = reconst_images.view(reconst_images.size(0), 1, 28, 28)
        torchvision.utils.save_image(
            reconst_images.data.cpu(),
            './data/reconst_images_%d.png' % (epoch + 1))
    plt.ioff()
    plt.show()
示例#13
0
def main():
    """Interpolate between two face images in VAE latent space and plot the
    inputs, their reconstructions, and the interpolated face."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-i1", type=str, help='input image')
    parser.add_argument("-i2", type=str, help='input image')
    parser.add_argument("-w", type=str, default='', help='Model weights')
    parser.add_argument("-a",
                        type=float,
                        default=0.5,
                        help='interpolation between 0 and 1')
    args = parser.parse_args()

    assert isinstance(args.i1, str)
    assert isinstance(args.i2, str)
    assert isinstance(args.w, str)

    # Restore the trained model.
    dump = torch.load(args.w)
    vae = model.VAE(dump['input_shape'], dump['z_dim']).cuda()
    vae.load_state_dict(dump['state_dict'])
    vae.eval()

    # Read both images as float CHW arrays in [0, 1].
    def _read(path):
        arr = np.asarray(Image.open(path).resize((112, 128))) / 255
        return np.transpose(arr, [2, 0, 1])

    img1 = _read(args.i1)
    img2 = _read(args.i2)

    def _to_batch(arr):
        return torch.tensor(arr, dtype=torch.float32).unsqueeze(0).cuda()

    # Encode both images in a single batch of two.
    img_v = torch.cat((_to_batch(img1), _to_batch(img2)), 0)
    mu, log_var, output_v = vae.forward(img_v)
    out_img = output_v.detach().squeeze(0).cpu().numpy()

    fig = plt.figure()
    # Decode the latent point `alpha` of the way from image 1 to image 2.
    alpha = args.a
    mu_ = (mu[0] + alpha * (mu[1] - mu[0])).unsqueeze(0)
    out_mean = vae.forward_decoder(mu_)
    mean_img = out_mean.detach().squeeze(0).cpu().numpy()

    panels = [
        np.transpose(img1, [1, 2, 0]),
        np.transpose(img2, [1, 2, 0]),
        np.transpose(out_img[0], [1, 2, 0]),
        np.transpose(out_img[1], [1, 2, 0]),
        np.transpose(mean_img, [1, 2, 0]),
    ]
    for pos, panel in enumerate(panels, start=1):
        plt.subplot(3, 2, pos, xticks=[], yticks=[])
        plt.imshow(panel)
    plt.pause(0.02)
    plt.show()
示例#14
0
# Load the preprocessed corpus bundle and copy its metadata onto args.
data = torch.load(args.data)
# args.max_len = data["max_word_len"]
# Hard cap on sequence length (overrides the stored max_word_len above).
args.max_len = 30
args.vocab_size = data['vocab_size']
args.pre_w2v = data['pre_w2v']
# Invert word->index so model output can be decoded back to tokens.
args.idx2word = {v: k for k, v in data['word2idx'].items()}

dl = DataLoader(data['train'], args.max_len, args.batch_size)
training_data = dl.ds_loader
'''
build model
'''
import model
from optim import ScheduledOptim, SetCriterion
# from metric import SetCriterion
# Build the VAE and move it to GPU when available.
vae = model.VAE(args)
if use_cuda:
    vae = vae.cuda()

# Criterion ignores padding plus common punctuation labels.
criterion = SetCriterion(data['word2idx'],
                         label_ignore=['。', ',', '、', ' '],
                         ignore_index=PAD)

# Adam wrapped in a warmup learning-rate schedule with gradient clipping.
optimizer = ScheduledOptim(
    torch.optim.Adam(vae.parameters(), betas=(0.9, 0.98), eps=1e-09),
    args.embed_dim, args.n_warmup_steps, vae.parameters(), args.clip)
'''
train model
'''
import time
from tqdm import tqdm
示例#15
0
import data
from torch.utils.data import DataLoader
import os
from tqdm import tqdm

# Inference only: no gradients are needed anywhere below.
torch.set_grad_enabled(False)
if torch.cuda.is_available():
    device_name = 'cuda'
else:
    device_name = 'cpu'
device = torch.device(device_name)
cpu = torch.device('cpu')

# load model
dump = torch.load('vae-200.dat', map_location=device_name)
vae = model.VAE(dump['input_shape'], dump['z_dim']).to(device)
vae.load_state_dict(dump['state_dict'])
vae.eval()

z_dim = dump['z_dim']

# data
feature = 'Eyeglasses'
# Accumulator for a latent-space direction associated with the attribute.
vector = torch.zeros(z_dim, dtype=torch.float32)
for positive in [True, False]:

    dataset = data.Dataset('faces/celeba-dataset',
                           'Eyeglasses',
                           positive=positive)
    dataset = DataLoader(dataset, batch_size=32, num_workers=4)
示例#16
0
    print(args, file=f)

# retrieve dataloader
trainset = NucleiDataset(datadir=args.datadir, mode="train")
testset = NucleiDataset(datadir=args.datadir, mode="test")

# drop_last=True keeps training batches uniform; test keeps every sample.
train_loader = DataLoader(
    trainset, batch_size=args.batch_size, drop_last=True, shuffle=True
)
test_loader = DataLoader(
    testset, batch_size=args.batch_size, drop_last=False, shuffle=False
)

print("Data loaded")

# NOTE(review): `model` here is the VAE instance, shadowing any module of
# the same name imported earlier in this file — verify no later code
# expects the module.
model = AENet.VAE(latent_variable_size=args.nz, batchnorm=True)
if args.conditional:
    netCondClf = AENet.Simple_Classifier(nz=args.nz)

if args.pretrained_file is not None:
    model.load_state_dict(torch.load(args.pretrained_file))
    print("Pre-trained model loaded")
    sys.stdout.flush()

# Per-class weights for the cross-entropy term (presumably to counter class
# imbalance — verify against the loss that consumes them).
CE_weights = torch.FloatTensor([4.5, 0.5])

if torch.cuda.is_available():
    print("Using GPU")
    model.cuda()
    CE_weights = CE_weights.cuda()
    if args.conditional:
示例#17
0
def training(num_epochs,
             batch_size,
             bars,
             pianoroll,
             transpose,
             verbose,
             save_location,
             resume=False,
             initialize=False):
    """Train the music VAE, evaluating after every epoch and checkpointing.

    Parameters:
        num_epochs: total number of epochs to run (resume-aware).
        batch_size: minibatch size for training and evaluation.
        bars, pianoroll, transpose: dataset/model configuration flags.
        verbose: print per-batch diagnostics when True.
        save_location: checkpoint directory; made absolute if relative.
        resume: continue from the last saved checkpoint.
        initialize: warm-start from pretrained lstm_l2/fc_4 weights.

    NOTE(review): relies on module-level `beta` and `learning_rate` globals
    and on `data`/`model`/`musicVAE` imports — confirm they exist in this
    module.
    """
    # beta-VAE objective: reconstruction BCE (normalized per batch and bar)
    # plus the analytic KL term weighted by `beta`.
    def loss_func(x_hat, x, mean, std_deviation, beta):
        bce = F.binary_cross_entropy(x_hat, x, reduction='sum')
        bce = bce / (batch_size * bars)
        kl = -0.5 * torch.sum(1 + torch.log(std_deviation**2) - mean**2 -
                              std_deviation**2)
        kl = kl / batch_size
        loss = bce + beta * kl

        if verbose:
            print(
                "\t\tCross Entropy: \t{}\n\t\tKL-Divergence: \t{}\n\t\tFinal Loss: \t{}"
                .format(bce, kl, loss))

        return loss

    # no teacher forcing is used during evaluation, the model has to rely on its own previous outputs
    def evaluate():
        # Average validation loss over the whole eval set; restores train
        # mode before returning.
        vae.eval()
        avg_loss = 0
        i = 0

        for batch in eval_loader:
            i += 1
            if verbose:
                print("\tbatch " + str(i) + ":")
            batch = batch.to(device)
            vae_output, mu, log_var = vae(batch)
            loss = criterion(vae_output, batch, mu, log_var, beta)
            avg_loss += loss.item()
        avg_loss = avg_loss / (len(eval_set) / batch_size)
        vae.train()
        return avg_loss

    # Make a relative save_location absolute w.r.t. the project root.
    if save_location[0] != '/' and save_location[0] != '~':
        save_location = data.path_to_root + '/' + save_location

    stride = 1 if transpose else 3
    train_set = data.FinalDataset('train',
                                  bars,
                                  stride=stride,
                                  pianoroll=pianoroll,
                                  transpose=transpose)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True)
    eval_set = data.FinalDataset('validation',
                                 bars,
                                 stride=stride,
                                 pianoroll=pianoroll,
                                 transpose=transpose)
    eval_loader = torch.utils.data.DataLoader(eval_set, batch_size=batch_size)

    loss_list = []
    resume_epoch = 0
    min_eval_loss = sys.maxsize  # does not remember min loss if trainig is aborted and resumed, but all losses are saved in loss_list()

    if initialize:
        # Warm-start from the pretrained upper layers (lstm_l2 / fc_4).
        vae = musicVAE.get_pretrained_vae(bars=bars,
                                          pianoroll=pianoroll,
                                          transpose=transpose,
                                          verbose=verbose)
        if verbose:
            print("loaded the pretrained weights")
    else:
        vae = model.VAE(bars, pianoroll)
    vae.train()

    if resume:
        checkpoint = musicVAE.get_checkpoint(str(train_set), save_location)
        prefix = "last_"
        loss_list = checkpoint[prefix + 'loss_list']
        resume_epoch = checkpoint[prefix + 'epoch']

        print("found checkpoint\n\tresuming training at epoch " +
              str(resume_epoch) +
              ("\n\tlast train loss: {}\n\tlast eval loss: {}\n".format(
                  loss_list[-1][0], loss_list[-1][1]) if loss_list else "\n"))

        vae.load_state_dict(checkpoint[prefix + 'vae_state_dict'])

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    vae.to(device)
    print("using device " + str(device) + "\n")

    optimizer = optim.Adam(vae.parameters(), lr=learning_rate)
    if resume:
        # Restore optimizer state so Adam's moments continue from the checkpoint.
        optimizer.load_state_dict(checkpoint[prefix + 'optim_state_dict'])
    criterion = loss_func

    #compute initial evaluation loss before training
    if not resume:
        if verbose:
            print("Initial Evaluation\n\n")
        initial_loss = evaluate()
        print("\nInitial evaluation loss: " + str(initial_loss))

    for epoch in range(resume_epoch, num_epochs):
        if verbose:
            print("\n\n\n\nSTARTING EPOCH " + str(epoch) + "\n")
        avg_loss = 0
        avg_correct = 0
        i = 0

        for batch in train_loader:
            i += 1
            batch = batch.to(device)  # .to() returns a copy for tensors!

            optimizer.zero_grad()

            if verbose:
                print("\tbatch " + str(i) + ":")
            # Teacher forcing during training: the decoder sees the truth.
            vae.set_ground_truth(batch)
            vae_output, mu, log_var = vae(batch)
            loss = criterion(vae_output, batch, mu, log_var, beta)

            # to free some memory:
            vae_output = None
            batch = None

            loss.backward()
            optimizer.step()

            avg_loss += loss.item()

        if verbose:
            print("\n\n\nEvaluation of epoch " + str(epoch) + "\n")

        avg_loss = avg_loss / (len(train_set) / batch_size)
        eval_loss = evaluate()
        timestamp = datetime.now()
        timestamp = "{}.{}.  -  {}:{}".format(timestamp.day, timestamp.month,
                                              timestamp.hour, timestamp.minute)

        print("EPOCH " + str(epoch) + "\t\t(finished at:   " + timestamp + ")")
        print("\ttraining loss: \t\t" + str(avg_loss))
        loss_list.append((avg_loss, eval_loss))
        print("\tevaluation loss: \t" + str(eval_loss) + "\n")
        best = False
        if eval_loss <= min_eval_loss:
            min_eval_loss = eval_loss
            best = True

        if verbose:
            print(("\n\tsaving checkpoint.."))
        # Checkpoint every epoch; `best` marks new minimum eval loss.
        musicVAE.save(vae, optimizer, loss_list, epoch + 1, str(train_set),
                      best, save_location)
示例#18
0
文件: main.py 项目: LeChangAlex/vaal
def main(args):
    """VAAL-style active learning: repeatedly train task model, VAE, and
    discriminator on the labelled pool, then query new samples to label.

    Raises NotImplementedError for unsupported args.dataset values.
    """
    if args.dataset == 'cifar10':
        test_dataloader = data.DataLoader(datasets.CIFAR10(
            args.data_path,
            download=True,
            transform=cifar_transformer(),
            train=False),
                                          batch_size=args.batch_size,
                                          drop_last=False)

        train_dataset = CIFAR10(args.data_path)

        args.num_images = 50000
        args.budget = 2500
        args.initial_budget = 5000
        args.num_classes = 10
    elif args.dataset == 'cifar100':
        test_dataloader = data.DataLoader(datasets.CIFAR100(
            args.data_path,
            download=True,
            transform=cifar_transformer(),
            train=False),
                                          batch_size=args.batch_size,
                                          drop_last=False)

        train_dataset = CIFAR100(args.data_path)

        args.num_images = 50000
        args.budget = 2500
        args.initial_budget = 5000
        args.num_classes = 100

    elif args.dataset == 'imagenet':
        test_dataloader = data.DataLoader(datasets.ImageFolder(
            args.data_path, transform=imagenet_transformer()),
                                          drop_last=False,
                                          batch_size=args.batch_size)

        train_dataset = ImageNet(args.data_path)

        args.num_images = 1281167
        args.budget = 64060
        args.initial_budget = 128120
        args.num_classes = 1000
    else:
        raise NotImplementedError

    all_indices = set(np.arange(args.num_images))
    # BUG FIX: random.sample() rejects sets on Python >= 3.11 (deprecated
    # since 3.9); materialize the population as a list first.
    initial_indices = random.sample(list(all_indices), args.initial_budget)
    sampler = data.sampler.SubsetRandomSampler(initial_indices)

    # dataset with labels available
    querry_dataloader = data.DataLoader(train_dataset,
                                        sampler=sampler,
                                        batch_size=args.batch_size,
                                        drop_last=True)

    args.cuda = args.cuda and torch.cuda.is_available()
    solver = Solver(args, test_dataloader)

    splits = [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4]
    #splits = [0.4]

    current_indices = list(initial_indices)

    accuracies = []

    for split in splits:
        # need to retrain all the models on the new images
        # re initialize and retrain the models
        task_model = vgg.vgg16_bn(num_classes=args.num_classes)
        vae = model.VAE(args.latent_dim)
        discriminator = model.Discriminator(args.latent_dim,
                                            args.num_classes + 1)

        unlabeled_indices = np.setdiff1d(list(all_indices), current_indices)
        unlabeled_sampler = data.sampler.SubsetRandomSampler(unlabeled_indices)
        unlabeled_dataloader = data.DataLoader(train_dataset,
                                               sampler=unlabeled_sampler,
                                               batch_size=args.batch_size,
                                               drop_last=False)

        # train the models on the current data
        acc, vae, discriminator = solver.train(querry_dataloader, task_model,
                                               vae, discriminator,
                                               unlabeled_dataloader, args)

        print('Final accuracy with {}% of data is: {:.2f}'.format(
            int(split * 100), acc))
        accuracies.append(acc)

        # Query the next batch of points to label and grow the labelled pool.
        sampled_indices = solver.sample_for_labeling(vae, discriminator,
                                                     unlabeled_dataloader)
        current_indices = list(current_indices) + list(sampled_indices)
        sampler = data.sampler.SubsetRandomSampler(current_indices)
        querry_dataloader = data.DataLoader(train_dataset,
                                            sampler=sampler,
                                            batch_size=args.batch_size,
                                            drop_last=True)

    torch.save(accuracies, os.path.join(args.out_path, args.log_name))
示例#19
0
文件: train.py 项目: ashwindcruz/dgm
# NOTE(review): this chunk uses Python 2 print statements — it will not run
# under Python 3 as-is.
# Read model hyper-parameters from the docopt-style `args` mapping.
print "%d hidden dimensions" % nhidden
nlatent = int(args['--nlatent'])
print "%d latent VAE dimensions" % nlatent
zcount = int(args['--vae-samples'])
print "Using %d VAE samples per instance" % zcount
nmap = int(args['--nmap'])
print "Using %d planar flow mappings" % nmap

log_interval = int(args['--log-interval'])
print "Recording training and testing ELBO every %d batches" % log_interval

# Setup training parameters
batchsize = int(args['--batchsize'])
print "Using a batchsize of %d instances" % batchsize

# Build the Chainer VAE and its Adam optimizer with gradient clipping.
vae = model.VAE(d, nhidden, nlatent, zcount, nmap)
opt = optimizers.Adam()
opt.setup(vae)
opt.add_hook(chainer.optimizer.GradientClipping(4.0))

# Move to GPU
gpu_id = int(args['--device'])
if gpu_id >= 0:
    cuda.check_cuda_available(
    )  # comment out to surpress an unncessarry warning
# Select the array backend: CuPy on GPU, NumPy on CPU.
if gpu_id >= 0:
    xp = cuda.cupy
    vae.to_gpu(gpu_id)
else:
    xp = np
示例#20
0
文件: main.py 项目: gnoluna/vaal
def main(args):
    """Run the VAAL-style active-learning loop.

    Each round: (1) retrain a task model, VAE and discriminator on the
    currently labeled pool, (2) log accuracy / micro-F1, (3) query the
    sampler for new indices to label.  Stops once the whole training set
    is labeled; results are saved to ``args.out_path/args.log_name``.

    Fixes vs. the original:
    * ``split`` was computed once before ``while split < N_samples`` and
      never updated, and a fraction was compared against the sample count,
      so the loop never terminated.  ``split`` is now the labeled fraction
      ``len(current_indices) / N_samples`` and is recomputed every round.
    * ``split`` was ``len(querry_dataloader) / N_samples`` (number of
      *batches* over number of samples); it is now the labeled fraction.
    * ``random.sample`` no longer receives a ``set`` (TypeError on
      Python 3.11+).
    """
    if args.dataset == 'cifar10':
        test_dataloader = data.DataLoader(datasets.CIFAR10(
            args.data_path,
            download=True,
            transform=cifar_transformer(),
            train=False),
                                          batch_size=args.batch_size,
                                          drop_last=False)

        train_dataset = CIFAR10(args.data_path)

        args.num_images = 50000
        args.budget = 2500
        args.initial_budget = 5000
        args.num_classes = 10
    elif args.dataset == 'cifar100':
        test_dataloader = data.DataLoader(datasets.CIFAR100(
            args.data_path,
            download=True,
            transform=cifar_transformer(),
            train=False),
                                          batch_size=args.batch_size,
                                          drop_last=False)

        train_dataset = CIFAR100(args.data_path)

        args.num_images = 50000
        args.budget = 2500
        args.initial_budget = 5000
        args.num_classes = 100

    elif args.dataset == 'imagenet':
        test_dataloader = data.DataLoader(datasets.ImageFolder(
            args.data_path, transform=imagenet_transformer()),
                                          drop_last=False,
                                          batch_size=args.batch_size)

        train_dataset = ImageNet(args.data_path)

        args.num_images = 1281167
        args.budget = 64060
        args.initial_budget = 128120
        args.num_classes = 1000
    elif args.dataset == 'semeval':
        train_dataset = SemEvalRes('train')
        test_dataloader = data.DataLoader(SemEvalRes('test',
                                                     mlb=train_dataset.mlb),
                                          batch_size=args.batch_size,
                                          drop_last=False)

        # budget / initial pool are fractions of the dataset size
        args.num_images = len(train_dataset)
        args.budget = int(0.05 * args.num_images)
        args.initial_budget = int(0.1 * args.num_images)
        args.num_classes = train_dataset.n_classes

        emb_size = train_dataset.emb_size
    else:
        raise NotImplementedError

    # NOTE(review): emb_size (used for the task model and VAE below) is only
    # defined in the 'semeval' branch; other datasets will fail at model
    # construction — confirm whether they are still meant to be supported.
    N_samples = args.num_images

    all_indices = set(np.arange(args.num_images))
    # random.sample requires a sequence (sampling from a set raises
    # TypeError on Python 3.11+), so materialize a stable ordering first.
    initial_indices = random.sample(sorted(all_indices), args.initial_budget)
    sampler = data.sampler.SubsetRandomSampler(initial_indices)

    # dataset with labels available
    querry_dataloader = data.DataLoader(train_dataset,
                                        sampler=sampler,
                                        batch_size=args.batch_size,
                                        drop_last=True)

    args.cuda = args.cuda and torch.cuda.is_available()
    solver = Solver(args, test_dataloader)

    current_indices = list(initial_indices)

    splits = []      # labeled fraction at the start of each round
    accuracies = []
    f1s = []

    # fraction of the training pool that is currently labeled
    split = len(current_indices) / N_samples
    while split < 1.0:
        splits.append(split)
        # need to retrain all the models on the new images
        # re initialize and retrain the models
        # task_model = vgg.vgg16_bn(num_classes=args.num_classes)
        task_model = linear.LinearModel(emb_size, args.num_classes)
        vae = model.VAE(args.latent_dim, emb_size)
        discriminator = model.Discriminator(args.latent_dim)

        unlabeled_indices = np.setdiff1d(list(all_indices), current_indices)
        unlabeled_sampler = data.sampler.SubsetRandomSampler(unlabeled_indices)
        unlabeled_dataloader = data.DataLoader(train_dataset,
                                               sampler=unlabeled_sampler,
                                               batch_size=args.batch_size,
                                               drop_last=False)

        # train the models on the current data
        acc, f1, vae, discriminator = solver.train(querry_dataloader,
                                                   task_model, vae,
                                                   discriminator,
                                                   unlabeled_dataloader)

        print('Final accuracy with {}% of data is: {:.2f}'.format(
            int(split * 100), acc))
        print('Final f1 micro with {}% of data is: {:.2f}'.format(
            int(split * 100), f1))
        accuracies.append(acc)
        f1s.append(f1)

        sampled_indices = list(
            solver.sample_for_labeling(vae, discriminator,
                                       unlabeled_dataloader))
        if not sampled_indices:
            # nothing left to query — avoid looping forever
            break
        current_indices = list(current_indices) + sampled_indices
        sampler = data.sampler.SubsetRandomSampler(current_indices)
        querry_dataloader = data.DataLoader(train_dataset,
                                            sampler=sampler,
                                            batch_size=args.batch_size,
                                            drop_last=True)

        # advance the labeled fraction so the loop terminates once the
        # whole pool has been labeled
        split = len(current_indices) / N_samples

    perf = {'labeled_percentage': splits, 'accuracies': accuracies, 'f1': f1s}
    torch.save(perf, os.path.join(args.out_path, args.log_name))
示例#21
0
# Script fragment: build the index->word map, the training iterator, and the
# VAE + scheduled Adam optimizer (Transformer-style warmup).
args.idx2word = {v: k for k, v in data['dict']['src'].items()}

training_data = DataLoader(data['train'],
                           args.max_len,
                           args.batch_size,
                           cuda=use_cuda)

# If no warmup length was given, fall back to one epoch's worth of steps.
# NOTE(review): this reads the loader's private `_stop_step` — confirm the
# DataLoader implementation guarantees it.
args.n_warmup_steps = args.n_warmup_steps if args.n_warmup_steps != 0 else training_data._stop_step

# ##############################################################################
# Build model
# ##############################################################################
import model
from optim import ScheduledOptim

vae = model.VAE(args)
if use_cuda:
    vae = vae.cuda()

criterion = torch.nn.CrossEntropyLoss()

# Adam wrapped in a warmup scheduler; also receives the parameters again for
# gradient clipping (args.clip).
optimizer = ScheduledOptim(
    torch.optim.Adam(vae.parameters(), betas=(0.9, 0.98), eps=1e-09),
    args.embed_dim, args.n_warmup_steps, vae.parameters(), args.clip)

# ##############################################################################
# Training
# ##############################################################################
import time
from tqdm import tqdm
示例#22
0
文件: train.py 项目: needleworm/pizza
def VAE():
    """Build train/valid/test VAE graphs with shared variables, set up
    TensorBoard summaries, checkpointing and a session, then run the
    training or testing loop depending on ``FLAGS.mode``.

    Fix vs. the original: the validation summary was fed ``train_loss``
    instead of ``valid_loss``, so the TensorBoard validation curve silently
    duplicated the training curve.
    """
    #                               Graph Part                                 #
    print("Graph initialization...")
    with tf.device(FLAGS.device_train):
        # Three graphs over the same variables: reuse=None creates them,
        # reuse=True shares them for the eval-mode valid/test models.
        with tf.variable_scope("model", reuse=None):
            m_train = G.VAE(batch_size=FLAGS.tr_batch_size,
                            is_training=True,
                            num_keys=FLAGS.num_keys,
                            input_length=FLAGS.hidden_state_size,
                            output_length=FLAGS.predict_size,
                            learning_rate=learning_rate)

        with tf.variable_scope("model", reuse=True):
            m_valid = G.VAE(batch_size=FLAGS.val_batch_size,
                            is_training=False,
                            num_keys=FLAGS.num_keys,
                            input_length=FLAGS.hidden_state_size,
                            output_length=FLAGS.predict_size,
                            learning_rate=learning_rate)

        with tf.variable_scope("model", reuse=True):
            m_test = G.VAE(batch_size=FLAGS.test_batch_size,
                           is_training=False,
                           num_keys=FLAGS.num_keys,
                           input_length=FLAGS.hidden_state_size,
                           output_length=FLAGS.predict_size,
                           learning_rate=learning_rate)
    print("Done")

    #                               Summary Part                               #
    print("Setting up summary op...")
    # A single scalar summary op fed through a placeholder so it can record
    # either training or validation loss.
    loss_ph = tf.placeholder(dtype=tf.float32)
    loss_summary_op = tf.summary.scalar("loss", loss_ph)
    valid_summary_writer = tf.summary.FileWriter(logs_dir + '/valid/',
                                                 max_queue=2)
    train_summary_writer = tf.summary.FileWriter(logs_dir + '/train/',
                                                 max_queue=2)
    print("Done")

    #                               Model Save Part                            #
    print("Setting up Saver...")
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(logs_dir)
    print("Done")

    #                               Session Part                               #
    print("Setting up Data Reader...")
    # NOTE(review): the validation reader reads from test_dir — confirm a
    # separate validation directory was not intended.
    validation_dataset_reader = mt.Dataset(
        directory=test_dir,
        batch_size=FLAGS.val_batch_size,
        is_batch_zero_pad=FLAGS.is_batch_zero_pad,
        hidden_state_size=FLAGS.hidden_state_size,
        predict_size=FLAGS.predict_size,
        num_keys=FLAGS.num_keys,
        tick_interval=tick_interval,
        step=FLAGS.slice_step)
    test_dataset_reader = mt.Dataset(directory=test_dir,
                                     batch_size=FLAGS.test_batch_size,
                                     is_batch_zero_pad=FLAGS.is_batch_zero_pad,
                                     hidden_state_size=FLAGS.hidden_state_size,
                                     predict_size=FLAGS.predict_size,
                                     num_keys=FLAGS.num_keys,
                                     tick_interval=tick_interval,
                                     step=FLAGS.slice_step)
    print("done")

    sess_config = tf.ConfigProto(allow_soft_placement=True,
                                 log_device_placement=False)
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)

    if ckpt and ckpt.model_checkpoint_path:  # model restore
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")
    else:
        sess.run(tf.global_variables_initializer()
                 )  # if the checkpoint doesn't exist, do initialization

    if FLAGS.mode == "train":
        train_dataset_reader = mt.Dataset(
            directory=train_dir,
            batch_size=FLAGS.tr_batch_size,
            is_batch_zero_pad=FLAGS.is_batch_zero_pad,
            hidden_state_size=FLAGS.hidden_state_size,
            predict_size=FLAGS.predict_size,
            num_keys=FLAGS.num_keys,
            tick_interval=tick_interval,
            step=FLAGS.slice_step)
        for itr in range(MAX_MAX_EPOCH):
            feed_dict = utils.vae_run_epoch(train_dataset_reader,
                                            FLAGS.tr_batch_size, m_train, sess,
                                            dropout_rate)

            # log training loss every 10 iterations
            if itr % 10 == 0:
                train_loss, train_pred = sess.run(
                    [m_train.loss, m_train.predict], feed_dict=feed_dict)
                train_summary_str = sess.run(loss_summary_op,
                                             feed_dict={loss_ph: train_loss})
                train_summary_writer.add_summary(train_summary_str, itr)
                print("Step : %d  TRAINING LOSS %g" % (itr, train_loss))

            # validate every 100 iterations
            if itr % 100 == 0:
                valid_loss, valid_pred = utils.vae_validation(
                    validation_dataset_reader, FLAGS.val_batch_size, m_valid,
                    FLAGS.hidden_state_size, FLAGS.predict_size, sess,
                    logs_dir, itr, tick_interval)
                # feed the *validation* loss (the original fed train_loss here)
                valid_summary_str = sess.run(loss_summary_op,
                                             feed_dict={loss_ph: valid_loss})
                valid_summary_writer.add_summary(valid_summary_str, itr)
                print("Step : %d  VALIDATION LOSS %g" % (itr, valid_loss))

            # periodic test-set generation and checkpointing
            if itr % 500 == 0:
                utils.test_model(test_dataset_reader, FLAGS.test_batch_size,
                                 m_test, FLAGS.predict_size, sess, logs_dir,
                                 itr, tick_interval, 10)
            if itr % 1000 == 0:
                saver.save(sess, logs_dir + "/model.ckpt", itr)

    if FLAGS.mode == "test":
        utils.test_model(test_dataset_reader, FLAGS.test_batch_size, m_test,
                         FLAGS.predict_size, sess, logs_dir, 9999,
                         tick_interval, 10)
示例#23
0
def main(args):
    """Imbalanced active-learning experiment driver.

    Builds dataset-specific loaders and budgets, then runs 10 acquisition
    rounds: retrain task model + VAE + discriminator on the labeled pool,
    record accuracy, and query ``args.budget`` new samples per round.
    Accuracies are saved to ``args.out_path/args.log_name``.

    Fixes vs. the original:
    * ``random.sample`` was called on a ``set`` (TypeError on
      Python 3.11+); it now samples from a sorted list.
    * An unsupported dataset in the task-model selection only printed
      'WRONG DATASET!' and then crashed later with ``NameError``; it now
      raises immediately.
    """
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    init_seeds(seed=int(time.time()))
    kwargs = {'num_workers': 2, 'pin_memory': True} if args.cuda else {}
    print(args.dataset)

    if args.dataset == 'MNIST':
        test_dataloader = data.DataLoader(
            MNIST(args.data_path, args.run_folder, transform=mnist_transformer()),
                batch_size=10000, shuffle=False, **kwargs)

        train_dataset = MNIST(args.data_path, args.run_folder, train=True, transform=mnist_transformer(), imbalance_ratio=args.imbalance_ratio)

        # image counts differ when the class-imbalanced subset is used
        if args.imbalance_ratio == 100:
            args.num_images = 25711
        else:
            args.num_images = 50000

        args.budget = 125
        args.initial_budget = 125
        args.num_classes = 10
        args.num_channels = 1
        args.arch_scaler = 2
    elif args.dataset == 'SVHN':
        test_dataloader = data.DataLoader(
            SVHN(args.data_path, args.run_folder, transform=svhn_transformer()),
                batch_size=5000, shuffle=False, **kwargs)

        train_dataset = SVHN(args.data_path, args.run_folder, train=True, transform=svhn_transformer(), imbalance_ratio=args.imbalance_ratio)

        if args.imbalance_ratio == 100:
            args.num_images = 318556
        else:
            args.num_images = 500000

        args.budget = 1250
        args.initial_budget = 1250
        args.num_classes = 10
        args.num_channels = 3
        args.arch_scaler = 1
    elif args.dataset == 'cifar10':
        # NOTE(review): the cifar branches set no arch_scaler, but
        # model.VAE below requires args.arch_scaler, and no task model is
        # defined for cifar either — confirm these datasets are supported.
        test_dataloader = data.DataLoader(
                datasets.CIFAR10(args.data_path, download=True, transform=cifar_transformer(), train=False),
            batch_size=args.batch_size, drop_last=False)

        train_dataset = CIFAR10(args.data_path)

        args.num_images = 50000
        args.budget = 2500
        args.initial_budget = 5000
        args.num_classes = 10
        args.num_channels = 3
    elif args.dataset == 'cifar100':
        test_dataloader = data.DataLoader(
                datasets.CIFAR100(args.data_path, download=True, transform=cifar_transformer(), train=False),
             batch_size=args.batch_size, drop_last=False)

        train_dataset = CIFAR100(args.data_path)

        args.num_images = 50000
        args.budget = 2500
        args.initial_budget = 5000
        args.num_classes = 100
        args.num_channels = 3
    elif args.dataset == 'ImageNet':
        test_dataloader = data.DataLoader(
            ImageNet(args.data_path + '/val', transform=imagenet_test_transformer()),
                batch_size=args.batch_size, shuffle=False, drop_last=False, **kwargs)

        if args.imbalance_ratio == 100:
            train_dataset = ImageNet(args.data_path + '/train_ir_100', transform=imagenet_train_transformer())
            args.num_images = 645770
        else:
            train_dataset = ImageNet(args.data_path + '/train', transform=imagenet_train_transformer())
            args.num_images = 1281167

        args.budget = 64000
        args.initial_budget = 64000
        args.num_classes = 1000
        args.num_channels = 3
        args.arch_scaler = 1
    else:
        raise NotImplementedError

    all_indices = set(np.arange(args.num_images))
    # random.sample requires a sequence (sampling from a set raises
    # TypeError on Python 3.11+), so materialize a stable ordering first.
    initial_indices = random.sample(sorted(all_indices), args.initial_budget)
    sampler = data.sampler.SubsetRandomSampler(initial_indices)
    # dataset with labels available
    querry_dataloader = data.DataLoader(train_dataset, sampler=sampler,
            batch_size=args.batch_size, drop_last=False, **kwargs)
    print('Sampler size =', len(querry_dataloader))
    solver = Solver(args, test_dataloader)

    splits = range(1,11)

    current_indices = list(initial_indices)

    accuracies = []

    for split in splits:
        print("Split =", split)
        # need to retrain all the models on the new images
        # re initialize and retrain the models
        #task_model = vgg.vgg16_bn(num_classes=args.num_classes)
        if args.dataset == 'MNIST':
            task_model = model.LeNet(num_classes=args.num_classes)
        elif args.dataset == 'SVHN':
            task_model = resnet.resnet10(num_classes=args.num_classes)
        elif args.dataset == 'ImageNet':
            task_model = resnet.resnet18(num_classes=args.num_classes)
        else:
            # fail fast instead of printing and crashing later with an
            # undefined task_model
            raise NotImplementedError('WRONG DATASET!')
        # loading pretrained
        if args.pretrained:
            print("Loading pretrained model", args.pretrained)
            checkpoint = torch.load(args.pretrained)
            task_model.load_state_dict({k: v for k, v in checkpoint['state_dict'].items() if 'fc' not in k}, strict=False) # copy all but last linear layers
        #
        vae = model.VAE(z_dim=args.latent_dim, nc=args.num_channels, s=args.arch_scaler)
        discriminator = model.Discriminator(z_dim=args.latent_dim, s=args.arch_scaler)

        unlabeled_indices = np.setdiff1d(list(all_indices), current_indices)
        unlabeled_sampler = data.sampler.SubsetRandomSampler(unlabeled_indices)
        unlabeled_dataloader = data.DataLoader(train_dataset, sampler=unlabeled_sampler,
                batch_size=args.batch_size, drop_last=False, **kwargs)

        # train the models on the current data
        acc, vae, discriminator = solver.train(querry_dataloader,
                                               task_model,
                                               vae,
                                               discriminator,
                                               unlabeled_dataloader)

        print('Final accuracy with {}% of data is: {:.2f}'.format(int(split*100.0*args.budget/args.num_images), acc))
        accuracies.append(acc)

        # extend the labeled pool with the newly queried indices
        sampled_indices = solver.sample_for_labeling(vae, discriminator, unlabeled_dataloader)
        current_indices = list(current_indices) + list(sampled_indices)
        sampler = data.sampler.SubsetRandomSampler(current_indices)
        querry_dataloader = data.DataLoader(train_dataset, sampler=sampler,
                batch_size=args.batch_size, drop_last=False, **kwargs)

    torch.save(accuracies, os.path.join(args.out_path, args.log_name))
示例#24
0
def main():
    """Train a convolutional VAE on an image folder.

    Parses CLI arguments, builds the dataloader and model, then trains for
    ``--epochs`` epochs, logging loss and preview images to TensorBoard and
    saving weights whenever the epoch loss reaches a new minimum.

    Fixes vs. the original:
    * ``args.load is not ''`` compared identity with a string literal
      (SyntaxWarning on Python 3.8+, implementation-dependent result);
      replaced with a plain truthiness check.
    * Removed a dead ``next(iter(trainloader))`` that loaded and discarded
      a batch before training.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda",
                        default=False,
                        action="store_true",
                        help="Enable CUDA")
    parser.add_argument("--data_dir", help="Folder where the data is located")
    parser.add_argument("--epochs",
                        type=int,
                        help='Number of times to iterate the whole dataset')
    parser.add_argument("--visual_every",
                        default=200,
                        type=int,
                        help='Display faces every n batches')
    parser.add_argument("--z_dim",
                        type=int,
                        default=200,
                        help='Dimensions of latent space')
    parser.add_argument("--r_loss_factor",
                        type=float,
                        default=10000.0,
                        help='r_loss factor')
    parser.add_argument("--lr",
                        type=float,
                        default=0.002,
                        help='Learning rate')
    parser.add_argument("--batch_size",
                        default=32,
                        type=int,
                        help='Batch size')
    parser.add_argument("--load",
                        type=str,
                        default='',
                        help='Load pretrained weights')
    args = parser.parse_args()

    # data where the images are located
    data_dir = args.data_dir
    assert isinstance(data_dir, str)
    assert isinstance(args.epochs, int)
    assert isinstance(args.visual_every, int)
    assert isinstance(args.z_dim, int)
    assert isinstance(args.r_loss_factor, float)
    assert isinstance(args.lr, float)
    assert isinstance(args.batch_size, int)

    # use CPU or GPU
    device = torch.device("cuda" if args.cuda else "cpu")

    # prepare data
    train_transforms = transforms.Compose([
        # transforms.Resize((176, 144)),
        transforms.Resize((128, 112)),
        transforms.ToTensor(),
    ])

    train_data = datasets.ImageFolder(data_dir, transform=train_transforms)
    trainloader = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              pin_memory=True)

    # create model — input shape (C, H, W) taken from one sample batch
    input_shape = next(iter(trainloader))[0].shape
    vae = model.VAE(input_shape[-3:], args.z_dim).to(device)
    print(vae)  # print for feedback

    # load previous weights (if any)
    if args.load:
        vae.load_state_dict(torch.load(args.load)['state_dict'])
        print("Weights loaded: {}".format(args.load))

    # create tensorboard writer
    writer = SummaryWriter(comment='-' + 'VAE' + str(args.z_dim))

    optimizer = optim.Adam(vae.parameters(), lr=args.lr)

    # generate random points in latent space so we can see how the network is training
    latent_space_test_points = np.random.normal(scale=1.0,
                                                size=(16, args.z_dim))
    latent_space_test_points_v = torch.Tensor(latent_space_test_points).to(
        device)

    batch_iterations = 0
    training_losses = []
    vae.train()
    for e in range(args.epochs):
        epoch_loss = []
        for images, labels in trainloader:
            images_v = images.to(device)

            optimizer.zero_grad()

            # total loss = KL term + weighted reconstruction term
            mu_v, log_var_v, images_out_v = vae(images_v)
            r_loss_v = r_loss(images_out_v, images_v)
            kl_loss_v = kl_loss(mu_v, log_var_v)
            loss = kl_loss_v + r_loss_v * args.r_loss_factor
            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.item())

            if batch_iterations % args.visual_every == 0:
                # print loss
                # NOTE(review): this "Batch" index mixes iteration count
                # with epochs scaled by batch_size — confirm the intended
                # metric.
                print("Batch: {}\tLoss: {}".format(
                    batch_iterations + e * len(trainloader) / args.batch_size,
                    loss.item()))
                writer.add_scalar('loss',
                                  np.mean(epoch_loss[-args.visual_every:]),
                                  batch_iterations)

            batch_iterations = batch_iterations + 1

        else:
            # for/else: runs after each completed epoch (the loop has no
            # break) — record epoch loss, checkpoint on improvement, and
            # render previews from the fixed latent points.
            training_losses.append(np.mean(epoch_loss))
            if min(training_losses) == training_losses[-1]:
                vae.save('vae-' + str(args.z_dim) + '.dat')

            vae.eval()

            generated_imgs_v = vae.forward_decoder(
                latent_space_test_points_v).detach()
            imgs_grid = utils.make_grid(generated_imgs_v)

            writer.add_image('preview-1',
                             imgs_grid.cpu().numpy(), batch_iterations)

            vae.train()
    label_thresh = 4  # include only a subset of MNIST classes
    idx = y_train < label_thresh  # only use digits 0, 1, 2, ...
elif "mnist2" in model_folder:
    label_thresh = 2  # include only a subset of MNIST classes
    idx = y_train == label_thresh  # only use digits 0, 1, 2, ...
else:
    label_thresh = 1  # include only a subset of MNIST classes
    idx = y_train == label_thresh  # only use digits 0, 1, 2, ...

# Script fragment: restrict the dataset to the classes selected by `idx`
# (computed just above), restore a trained VAE from its best checkpoint,
# and encode the training data to latent means.
num_classes = y_train[idx].unique().numel()
x_train = x_train[idx]
y_train = y_train[idx]
N = x_train.shape[0]  # number of retained training samples

# load model
net = model.VAE(x_train, layers, num_components=num_components, device=device)
ckpt = load_checkpoint(model_folder + 'best.pth.tar', net)
# Restore the GMM-based decoder-variance parameters from the checkpoint.
net.init_std(x_train,
             gmm_mu=ckpt['gmm_means'],
             gmm_cv=ckpt['gmm_cv'],
             weights=ckpt['weights'])
# Merge the saved weights into the current state dict (keeps any keys the
# checkpoint does not provide) before loading.
saved_dict = ckpt['state_dict']
new_dict = net.state_dict()
new_dict.update(saved_dict)
net.load_state_dict(new_dict)
net.eval()

# Encoder outputs (mean, log-var) concatenated on the last axis; chunk and
# keep only the means as latent codes.
with torch.no_grad():
    z = torch.chunk(net.encoder(x_train.to(device)), chunks=2, dim=-1)[0]
z_data = custom_dataset(z)
z_loader = torch.utils.data.DataLoader(z_data,
示例#26
0
def main(args):
    """Train task model + VAE + discriminator on the full labeled dataset.

    The active-learning machinery (validation split, budgets, unlabeled
    pool) is commented out in this variant — it trains on 100% of the
    data for ``args.train_epochs`` epochs.

    Fixes vs. the original: ``args.cuda`` was assigned twice (the second
    ``args.cuda and torch.cuda.is_available()`` was a no-op given the
    first), and an unused ``splits`` local was dropped.
    """
    if args.dataset == 'cifar10':
        test_dataloader = data.DataLoader(datasets.CIFAR10(
            args.data_path,
            download=True,
            transform=cifar_transformer(),
            train=False),
                                          batch_size=args.batch_size,
                                          drop_last=False)

        # train_dataset = CIFAR10(args.data_path)
        querry_dataloader = data.DataLoader(CIFAR10(args.data_path),
                                            batch_size=args.batch_size,
                                            drop_last=True)
        args.num_images = 50000
        # args.num_val = 5000
        # args.budget = 2500
        # args.initial_budget = 5000
        args.num_classes = 10
    elif args.dataset == 'cifar100':
        test_dataloader = data.DataLoader(datasets.CIFAR100(
            args.data_path,
            download=True,
            transform=cifar_transformer(),
            train=False),
                                          batch_size=args.batch_size,
                                          drop_last=False)

        train_dataset = CIFAR100(args.data_path)
        querry_dataloader = data.DataLoader(CIFAR100(args.data_path),
                                            batch_size=args.batch_size,
                                            shuffle=True,
                                            drop_last=True)

        args.num_val = 5000
        args.num_images = 50000
        args.budget = 2500
        args.initial_budget = 5000
        args.num_classes = 100

    elif args.dataset == 'tinyimagenet':
        test_dataloader = data.DataLoader(TinyImageNet(
            args.data_path, transform=tinyimagenet_transform(), train=False),
                                          batch_size=args.batch_size,
                                          drop_last=False)

        querry_dataloader = data.DataLoader(TinyImageNet(
            args.data_path, transform=tinyimagenet_transform(), train=True),
                                            shuffle=True,
                                            batch_size=args.batch_size,
                                            drop_last=True)
        args.num_classes = 200
        args.num_images = 100000

    elif args.dataset == 'imagenet':
        test_dataloader = data.DataLoader(datasets.ImageFolder(
            args.data_path, transform=imagenet_transformer()),
                                          drop_last=False,
                                          batch_size=args.batch_size)

        train_dataset = ImageNet(args.data_path)

        args.num_val = 128120
        args.num_images = 1281167
        args.budget = 64060
        args.initial_budget = 128120
        args.num_classes = 1000
    else:
        raise NotImplementedError

    args.cuda = torch.cuda.is_available()

    # all_indices = set(np.arange(args.num_images))
    # val_indices = random.sample(all_indices, args.num_val)
    # all_indices = np.setdiff1d(list(all_indices), val_indices)

    # initial_indices = random.sample(list(all_indices), args.initial_budget)
    # sampler = data.sampler.SubsetRandomSampler(initial_indices)
    # sampler = data.sampler.SubsetRandomSampler(list(all_indices))
    # val_sampler = data.sampler.SubsetRandomSampler(val_indices)

    # dataset with labels available
    # querry_dataloader = data.DataLoader(train_dataset, sampler=sampler,
    #         batch_size=args.batch_size, drop_last=True)
    # val_dataloader = data.DataLoader(train_dataset, sampler=val_sampler,
    #        batch_size=args.batch_size, drop_last=False)
    val_dataloader = None

    solver = Solver(args, test_dataloader)

    # current_indices = list(initial_indices)

    # accuracies = []

    print('==> Building models...')

    # task_model = vgg.vgg16_bn(num_classes=args.num_classes)
    task_model = resnet.ResNet34(num_classes=args.num_classes)
    # task_model = model.Approximator(args.latent_dim, args.num_classes)
    vae = model.VAE(args.latent_dim)
    discriminator = model.Discriminator(args.latent_dim)

    if args.cuda:
        vae = vae.cuda()
        discriminator = discriminator.cuda()
        task_model = task_model.cuda()
        vae = torch.nn.DataParallel(vae)
        discriminator = torch.nn.DataParallel(discriminator)
        task_model = torch.nn.DataParallel(task_model)

    for epoch in range(args.train_epochs):

        # unlabeled_indices = np.setdiff1d(list(all_indices), current_indices)
        # unlabeled_sampler = data.sampler.SubsetRandomSampler(unlabeled_indices)
        # unlabeled_dataloader = data.DataLoader(train_dataset,
        #         sampler=unlabeled_sampler, batch_size=args.batch_size, drop_last=False)
        unlabeled_dataloader = None
        print('\nEpoch: %d' % epoch)
        # train the models on the current data
        acc, vae, discriminator, task_model = solver.train(
            epoch, querry_dataloader, val_dataloader, task_model, vae,
            discriminator, unlabeled_dataloader)
示例#27
0
# Script fragment: load a pre-trained (frozen) VAE baseline and set up a
# speaker classifier over its latent space.
lr = 1
c_lr = args.c_lr  # classifier learning rate

batch_size = 8

epochs = args.epochs

total_time = 0

# Best-so-far tracking for dev loss (sentinel "infinity" value).
min_dev_loss = 9999999999999999
min_epoch = 0
d_epoch = 1

# Pre-trained VAE baseline; kept in eval mode (not trained further here).
pre_vae = model.VAE(style_dim=TOTAL_SPK_NUM,
                    latent_dim=latent_dim,
                    vae_type=args.model_type)
pre_vae.load_state_dict(torch.load(args.baseline))
pre_vae.cuda()
pre_vae.eval()

# Classifier that predicts the speaker identity from a latent code.
spk_C = model.LatentClassifier(latent_dim=latent_dim, label_num=TOTAL_SPK_NUM)
spk_C.cuda()
spk_C_opt = optim.Adam(spk_C.parameters(), lr=c_lr)
# spk_C_sch = optim.lr_scheduler.LambdaLR(optimizer=spk_C_opt, lr_lambda=lambda epoch: c_lr*(-(1e-2/(epochs+1))*epoch+1e-2))
print(calc_parm_num(spk_C))
print(spk_C)

# NOTE(review): this saves the *untrained* classifier state under the final
# epoch's name — confirm the training loop follows later in the file.
torch.save(spk_C.state_dict(),
           os.path.join(model_dir, "si_{}.pt".format(epochs)))
示例#28
0
# Script fragment: same speaker-classifier setup as the previous example,
# but the baseline VAE is built with the weight-sharing option.
lr = 1
c_lr = args.c_lr  # classifier learning rate

batch_size = 8

epochs = args.epochs

total_time = 0

# Best-so-far tracking for dev loss (sentinel "infinity" value).
min_dev_loss = 9999999999999999
min_epoch = 0
d_epoch = 1

# Pre-trained VAE baseline (with optional encoder/decoder weight sharing);
# kept in eval mode and not trained further here.
pre_vae = model.VAE(style_dim=TOTAL_SPK_NUM,
                    latent_dim=latent_dim,
                    vae_type=args.model_type,
                    weight_sharing=args.ws)
pre_vae.load_state_dict(torch.load(args.baseline))
pre_vae.cuda()
pre_vae.eval()

# Classifier that predicts the speaker identity from a latent code.
spk_C = model.LatentClassifier(latent_dim=latent_dim, label_num=TOTAL_SPK_NUM)
spk_C.cuda()
spk_C_opt = optim.Adam(spk_C.parameters(), lr=c_lr)
# spk_C_sch = optim.lr_scheduler.LambdaLR(optimizer=spk_C_opt, lr_lambda=lambda epoch: c_lr*(-(1e-2/(epochs+1))*epoch+1e-2))
print(calc_parm_num(spk_C))
print(spk_C)

# NOTE(review): this saves the *untrained* classifier state under the final
# epoch's name — confirm the training loop follows later in the file.
torch.save(spk_C.state_dict(),
           os.path.join(model_dir, "si_{}.pt".format(epochs)))
示例#29
0
## Data
x_train, y_train, N, x_test, y_test, N_test = data.load_data(
    experiment_parameters, root="./data")
train_loader, test_loader = data.data_split(x_train, x_test, batch_size)

# Fit mean network
# NOTE(review): assigning to `model` shadows the imported `model` module —
# any later use of the module under that name will break; consider renaming.
if experiment_parameters["dataset"] == "bodies":
    model = model.VAE_bodies(
        x_train,
        layers,
        num_components=experiment_parameters["num_components"],
        device=device)
else:
    model = model.VAE(x_train,
                      layers,
                      num_components=experiment_parameters["num_components"],
                      device=device)
# Train the decoder mean network first, with a KL cap.
model.fit_mean(train_loader, num_epochs=5, num_cycles=1, max_kl=1)

# fit std
# Initialize and then train the decoder variance (GMM-based) parameters.
model.init_std(x_train,
               inv_maxstd=experiment_parameters["inv_maxstd"],
               beta_constant=torch.Tensor(
                   [experiment_parameters["beta_constant"]]),
               component_overwrite=experiment_parameters["num_components"])
model.fit_std(train_loader, num_epochs=experiment_parameters["std_epochs"])

save_checkpoint(
    {
        'state_dict': model.state_dict(),
        'gmm_means': model.gmm_means,
示例#30
0
def main(args):
    """Run an active-learning loop (VAAL-style) over CIFAR-10/100 or ImageNet.

    Builds the dataset-specific test loader and per-dataset budgets, seeds an
    initial labelled pool of random indices, then for each split: retrains the
    task model, VAE, and discriminator on the labelled pool; records test
    accuracy; and asks the solver to pick new indices to "label" from the
    unlabelled pool. Accuracies are saved to args.out_path/args.log_name.

    Raises:
        NotImplementedError: if args.dataset is not one of the supported sets.
    """
    if args.dataset == 'cifar10':

        # NOTE(review): random.sample() on a set is deprecated since Python
        # 3.9 and raises TypeError on 3.11+ — needs list(all_indices).
        all_indices = set(np.arange(10000))
        initial_indices = random.sample(all_indices, 2000)
        test_sampler = data.sampler.SubsetRandomSampler(initial_indices)

        # Test loader restricted to a random 2000-image subset via the sampler.
        test_dataloader = data.DataLoader(datasets.CIFAR10(
            args.data_path,
            download=True,
            transform=cifar_transformer(),
            train=False),
                                          batch_size=args.batch_size,
                                          drop_last=False,
                                          sampler=test_sampler)
        '''
        The length is still (orig_length)
        But the times it will iterate through is only #random indices now
        '''
        # print(len(test_dataloader.dataset))
        # total = 0
        # for i, batch in enumerate(tqdm(test_dataloader)):
        #     print("good")
        #     total += len(batch[0])
        # print("total was {}".format(total))
        # construct a new dataset, from these iterated ones

        train_dataset = CIFAR10(args.data_path)

        # Per-dataset active-learning budgets (pool size, per-round query
        # budget, initial labelled-pool size, #classes).
        args.num_images = 5000  #a type of curriculum learning could be useful here!
        args.budget = 250
        args.initial_budget = 2000
        args.num_classes = 10
    elif args.dataset == 'cifar100':
        test_dataloader = data.DataLoader(datasets.CIFAR100(
            args.data_path,
            download=True,
            transform=cifar_transformer(),
            train=False),
                                          batch_size=args.batch_size,
                                          drop_last=False)

        train_dataset = CIFAR100(args.data_path)

        args.num_images = 50000
        args.budget = 2500
        args.initial_budget = 5000
        args.num_classes = 100

    elif args.dataset == 'imagenet':
        test_dataloader = data.DataLoader(datasets.ImageFolder(
            args.data_path, transform=imagenet_transformer()),
                                          drop_last=False,
                                          batch_size=args.batch_size)

        train_dataset = ImageNet(args.data_path)

        args.num_images = 1281167
        args.budget = 64060
        args.initial_budget = 128120
        args.num_classes = 1000
    else:
        raise NotImplementedError

    # Seed the labelled pool with a random subset of the training indices.
    # NOTE(review): same random.sample(set, k) issue as above on Python 3.11+.
    all_indices = set(np.arange(args.num_images))
    initial_indices = random.sample(all_indices, args.initial_budget)
    sampler = data.sampler.SubsetRandomSampler(initial_indices)

    # dataset with labels available
    querry_dataloader = data.DataLoader(train_dataset,
                                        sampler=sampler,
                                        batch_size=args.batch_size,
                                        drop_last=True)

    # print("in main")
    # print(len(querry_dataloader.dataset))

    args.cuda = args.cuda and torch.cuda.is_available()
    solver = Solver(args, test_dataloader)

    splits = [
        0.4, 0.45, 0.5
    ]  #splits actually has no effect on anything (just for formatting)!

    current_indices = list(initial_indices)

    accuracies = []

    # let's say we give it just 10% of the initial dataset, and say the rest was "unlabelled". We will see how it performs

    for split in tqdm(splits):
        # need to retrain all the models on the new images
        # re initialize and retrain the models
        # task_model = vgg.vgg16_bn(num_classes=args.num_classes)
        # Fresh models each round so earlier rounds don't leak into later ones.
        task_model = model.SimpleTaskModel(
            args.latent_dim * 2, args.num_classes
        )  # 2 for the different params (variance, mean); you are trying to predict
        vae = model.VAE(args.latent_dim)
        discriminator = model.Discriminator(args.latent_dim)

        # OK, so they do actually have a strong separation: we DO keep track of which indices have been sampled so far
        # Unlabelled pool = everything not yet moved into current_indices.
        unlabeled_indices = np.setdiff1d(list(all_indices), current_indices)
        unlabeled_sampler = data.sampler.SubsetRandomSampler(unlabeled_indices)
        unlabeled_dataloader = data.DataLoader(train_dataset,
                                               sampler=unlabeled_sampler,
                                               batch_size=args.batch_size,
                                               drop_last=False)

        # train the models on the current data
        acc, vae, discriminator = solver.train(querry_dataloader, task_model,
                                               vae, discriminator,
                                               unlabeled_dataloader)

        print('Final accuracy with {}% of data is: {:.2f}'.format(
            int(split * 100), acc))
        accuracies.append(acc)

        # Query step: the solver picks args.budget new indices to "label",
        # then the labelled-pool loader is rebuilt with the enlarged pool.
        sampled_indices = solver.sample_for_labeling(vae, discriminator,
                                                     unlabeled_dataloader)
        current_indices = list(current_indices) + list(sampled_indices)
        sampler = data.sampler.SubsetRandomSampler(current_indices)
        querry_dataloader = data.DataLoader(train_dataset,
                                            sampler=sampler,
                                            batch_size=args.batch_size,
                                            drop_last=True)

    torch.save(accuracies, os.path.join(args.out_path, args.log_name))