# Example 1
def train(pretrained=False):
    """Build an Autoencoder and fit it for one epoch with PyTorch Lightning.

    Args:
        pretrained: when True, skip fitting (weights are assumed to be
            loaded elsewhere); when False, fit the model and save its
            state dict to 'model.pth'.

    Returns:
        Tuple ``(None, trainer, None)`` — only the trainer is populated;
        the placeholders preserve the caller-facing return shape.
    """
    model = Autoencoder()
    print(model)
    trainer = pl.Trainer(
        max_epochs=1,
        gpus=None,  # CPU-only training
        progress_bar_refresh_rate=1,
        # fast_dev_run=True,
        early_stop_callback=early_stop,  # module-level early-stopping callback
        logger=logger)

    if not pretrained:
        # The return value of trainer.fit() was bound but never used.
        trainer.fit(model)
        torch.save(model.state_dict(), 'model.pth')
    return None, trainer, None
# Example 2
def main():
    """Train a fully-connected or convolutional autoencoder on
    MNIST / STL10 / CIFAR10 and save real vs. reconstructed sample images.

    All configuration comes from the command line; see the ``--...``
    options below for defaults.
    """

    def _str2bool(value):
        """argparse `type=` converter that actually parses booleans.

        ``type=bool`` is a classic argparse trap: ``bool('False')`` is True
        because every non-empty string is truthy, so ``--use-cuda False``
        would silently enable CUDA.  This keeps the ``True``/``False``
        CLI shape while parsing it correctly.
        """
        if isinstance(value, bool):
            return value
        if value.lower() in ('true', 't', 'yes', '1'):
            return True
        if value.lower() in ('false', 'f', 'no', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'Boolean value expected, got %r' % value)

    parser = argparse.ArgumentParser(
        description='Simple training script for training model')

    parser.add_argument(
        '--epochs', help='Number of epochs (default: 75)', type=int, default=75)
    parser.add_argument(
        '--batch-size', help='Batch size of the data (default: 16)', type=int, default=16)
    parser.add_argument(
        '--learning-rate', help='Learning rate (default: 0.001)', type=float, default=0.001)
    parser.add_argument(
        '--seed', help='Random seed (default:1)', type=int, default=1)
    parser.add_argument(
        '--data-path', help='Path for the downloaded dataset (default: ../dataset/)', default='../dataset/')
    parser.add_argument(
        '--dataset', help='Dataset name. Must be one of MNIST, STL10, CIFAR10')
    parser.add_argument(
        '--use-cuda', help='CUDA usage (default: False)', type=_str2bool, default=False)
    parser.add_argument(
        '--network-type', help='Type of the network layers. Must be one of Conv, FC (default: FC)', default='FC')
    parser.add_argument(
        '--weight-decay', help='weight decay (L2 penalty) (default: 1e-5)', type=float, default=1e-5)
    parser.add_argument(
        '--log-interval', help='No of batches to wait before logging training status (default: 50)', type=int, default=50)
    parser.add_argument(
        '--save-model', help='For saving the current model (default: True)', type=_str2bool, default=True)

    args = parser.parse_args()

    epochs = args.epochs  # number of epochs
    batch_size = args.batch_size  # batch size
    learning_rate = args.learning_rate  # learning rate
    torch.manual_seed(args.seed)  # seed for reproducibility

    # Creating dataset path if it doesn't exist
    if args.data_path is None:
        raise ValueError('Must provide dataset path')
    data_path = args.data_path
    # makedirs also creates missing parent directories and tolerates an
    # existing path; os.mkdir would raise in both cases.
    os.makedirs(data_path, exist_ok=True)

    # Downloading proper dataset and creating data loader.
    # Hidden/output widths are fixed ratios of the flattened input size.
    if args.dataset == 'MNIST':
        T = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        train_data = torchvision.datasets.MNIST(
            data_path, train=True, download=True, transform=T)
        test_data = torchvision.datasets.MNIST(
            data_path, train=False, download=True, transform=T)

        ip_dim = 1 * 28 * 28  # input dimension
        h1_dim = int(ip_dim / 2)  # hidden layer 1 dimension
        op_dim = int(ip_dim / 4)  # output dimension
    elif args.dataset == 'STL10':
        T = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        train_data = torchvision.datasets.STL10(
            data_path, split='train', download=True, transform=T)
        test_data = torchvision.datasets.STL10(
            data_path, split='test', download=True, transform=T)

        ip_dim = 3 * 96 * 96  # input dimension
        h1_dim = int(ip_dim / 2)  # hidden layer 1 dimension
        op_dim = int(ip_dim / 4)  # output dimension
    elif args.dataset == 'CIFAR10':
        T = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        train_data = torchvision.datasets.CIFAR10(
            data_path, train=True, download=True, transform=T)
        test_data = torchvision.datasets.CIFAR10(
            data_path, train=False, download=True, transform=T)

        ip_dim = 3 * 32 * 32  # input dimension
        h1_dim = int(ip_dim / 2)  # hidden layer 1 dimension
        op_dim = int(ip_dim / 4)  # output dimension
    elif args.dataset is None:
        raise ValueError('Must provide dataset')
    else:
        raise ValueError('Dataset name must be MNIST, STL10 or CIFAR10')

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # Pick the device; refuse --use-cuda when CUDA is absent, and nudge
    # the user when CUDA is available but unused.
    device = 'cpu'
    if not args.use_cuda:
        if torch.cuda.is_available():
            warnings.warn(
                'CUDA is available, please use for faster convergence')
    else:
        if torch.cuda.is_available():
            device = 'cuda'
        else:
            raise ValueError('CUDA is not available, please set it False')

    # Type of layer
    if args.network_type == 'FC':
        auto_encoder = Autoencoder(ip_dim, h1_dim, op_dim).to(device)
    elif args.network_type == 'Conv':
        auto_encoder = ConvolutionAE().to(device)
    else:
        raise ValueError('Network type must be either FC or Conv type')

    # Train the model
    auto_encoder.train()
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(
        lr=learning_rate, params=auto_encoder.parameters(), weight_decay=args.weight_decay)

    for n_epoch in range(epochs):  # loop over the dataset multiple times
        reconstruction_loss = 0.0
        for batch_idx, (X, Y) in enumerate(train_loader):
            X = X.view(X.size()[0], -1)  # flatten images to vectors
            X = Variable(X).to(device)

            encoded, decoded = auto_encoder(X)

            optimizer.zero_grad()
            loss = criterion(X, decoded)
            loss.backward()
            optimizer.step()

            reconstruction_loss += loss.item()
            if (batch_idx + 1) % args.log_interval == 0:
                # Report the mean loss over the last log_interval batches.
                print('[%d, %5d] Reconstruction loss: %.5f' %
                      (n_epoch + 1, batch_idx + 1, reconstruction_loss / args.log_interval))
                reconstruction_loss = 0.0
    if args.save_model:
        torch.save(auto_encoder.state_dict(), "Autoencoder.pth")

    # Save real images.  save_image does not create directories, so make
    # sure 'images/' exists first.
    os.makedirs('images', exist_ok=True)
    data_iter = iter(test_loader)
    # Iterator `.next()` was removed in Python 3 — use the builtin next().
    images, labels = next(data_iter)
    torchvision.utils.save_image(torchvision.utils.make_grid(
        images, nrow=4), 'images/actual_img.jpeg')

    # Load trained model and get decoded images
    auto_encoder.load_state_dict(torch.load('Autoencoder.pth'))
    auto_encoder.eval()
    images = images.view(images.size()[0], -1)
    images = Variable(images).to(device)
    encoded, decoded = auto_encoder(images)

    # Reshape flat decoder output back to image tensors per dataset.
    if args.dataset == 'MNIST':
        decoded = decoded.view(decoded.size()[0], 1, 28, 28)
    elif args.dataset == 'STL10':
        decoded = decoded.view(decoded.size()[0], 3, 96, 96)
    elif args.dataset == 'CIFAR10':
        decoded = decoded.view(decoded.size()[0], 3, 32, 32)
    torchvision.utils.save_image(torchvision.utils.make_grid(
        decoded, nrow=4), 'images/decoded_img.jpeg')
# Example 3 — file: train_ae.py, project: vwrs/IEGAN
def train():
    """Train the MNIST autoencoder.

    Saves a grid of reconstructions of a fixed sample batch after every
    epoch, the final weights to MODEL_FULLPATH, and (when ``opt.history``)
    the per-batch loss history as a NumPy array.
    """
    ae = Autoencoder()
    # load trained model
    # model_path = ''
    # g.load_state_dict(torch.load(model_path))

    criterion = torch.nn.MSELoss()
    optimizer = optim.Adam(ae.parameters(), lr=opt.lr, weight_decay=opt.decay)

    # load dataset
    # ==========================
    kwargs = dict(num_workers=1, pin_memory=True) if cuda else {}
    dataloader = DataLoader(
        datasets.MNIST('MNIST', download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor()
                       ])),
        batch_size=opt.batch_size, shuffle=True, **kwargs
    )
    N = len(dataloader)  # number of batches per epoch

    # Fixed sample batch used to visualize reconstructions each epoch.
    dataiter = iter(dataloader)
    # Iterator `.next()` was removed in Python 3 — use the builtin next().
    samples, _ = next(dataiter)
    # cuda
    if cuda:
        ae.cuda()
        criterion.cuda()
        samples = samples.cuda()
    samples = Variable(samples)

    if opt.history:
        # Pre-size the history buffer: one entry per batch per epoch.
        loss_history = np.empty(N*opt.epochs, dtype=np.float32)
    # train
    # ==========================
    for epoch in range(opt.epochs):
        loss_mean = 0.0
        for i, (imgs, _) in enumerate(dataloader):
            if cuda:
                imgs = imgs.cuda()
            imgs = Variable(imgs)

            # forward & backward & update params
            ae.zero_grad()
            _, outputs = ae(imgs)
            loss = criterion(outputs, imgs)
            loss.backward()
            optimizer.step()

            # `loss.data[0]` was removed in PyTorch 0.4; .item() is the
            # supported way to read a scalar tensor's Python value.
            loss_val = loss.item()
            loss_mean += loss_val
            if opt.history:
                loss_history[N*epoch + i] = loss_val
            show_progress(epoch+1, i+1, N, loss_val)

        print('\ttotal loss (mean): %f' % (loss_mean/N))
        # generate fake images
        _, reconst = ae(samples)
        vutils.save_image(reconst.data,
                          os.path.join(IMAGE_PATH,'%d.png' % (epoch+1)),
                          normalize=False)
    # save models
    torch.save(ae.state_dict(), MODEL_FULLPATH)
    # save loss history
    if opt.history:
        np.save('history/'+opt.name, loss_history)
# Example 4
def main(args):
    """Train a caption autoencoder plus CNN image encoder.

    Checkpoints both models whenever the combined loss improves, logs to
    TensorBoard every ``args.log_step`` batches, and additionally saves
    numbered checkpoints every ``args.save_step`` batches.
    """
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    data_loader = get_loader(args.image_dir,
                             args.caption_path,
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)

    # Build the models
    encoder = EncoderCNN(args.embed_size).to(device)
    autoencoder = Autoencoder(args.embed_size, args.embeddings_path,
                              args.hidden_size, len(vocab),
                              args.num_layers).to(device)
    print(len(vocab))

    # Optimize only trainable parameters.  NOTE(review): the first
    # autoencoder parameter is skipped — presumably frozen pretrained
    # embeddings (cf. args.embeddings_path); confirm against the model.
    params = list(
        filter(
            lambda p: p.requires_grad,
            list(autoencoder.parameters())[1:] +
            list(encoder.linear.parameters())))
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # Define summary writer
    writer = SummaryWriter()

    # Loss tracker
    best_loss = float('inf')

    # Train the models
    total_step = len(data_loader)
    for epoch in range(args.num_epochs):
        for i, (images, captions, lengths) in enumerate(data_loader):
            # Set mini-batch dataset.  (The original also built packed
            # targets with pack_padded_sequence here but never used them.)
            images = images.to(device)
            captions = captions.to(device)

            # Forward, backward and optimize
            features = encoder(images)
            L_ling, L_vis = autoencoder(features, captions, lengths)
            loss = 0.2 * L_ling + 0.8 * L_vis  # Want visual loss to have bigger impact
            autoencoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()

            # Save the model checkpoints when loss improves
            if loss.item() < best_loss:
                # Store the Python float, not the tensor: keeping the
                # tensor would also keep its whole autograd graph alive.
                best_loss = loss.item()
                print("Saving checkpoints")
                torch.save(
                    autoencoder.state_dict(),
                    os.path.join(args.model_path,
                                 'autoencoder-frozen-best.ckpt'))
                torch.save(
                    encoder.state_dict(),
                    os.path.join(args.model_path,
                                 'encoder-frozen-best.ckpt'))

            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                    .format(epoch, args.num_epochs, i, total_step, loss.item(),
                            np.exp(loss.item())))
                # Log train loss on tensorboard
                writer.add_scalar('frozen-loss/L_ling', L_ling.item(),
                                  epoch * total_step + i)
                writer.add_scalar('frozen-loss/L_vis', L_vis.item(),
                                  epoch * total_step + i)
                writer.add_scalar('frozen-loss/combined', loss.item(),
                                  epoch * total_step + i)

            # Save numbered model checkpoints
            if (i + 1) % args.save_step == 0:
                torch.save(
                    autoencoder.state_dict(),
                    os.path.join(
                        args.model_path,
                        'autoencoder-frozen-{}-{}.ckpt'.format(
                            epoch + 1, i + 1)))
                torch.save(
                    encoder.state_dict(),
                    os.path.join(
                        args.model_path,
                        'encoder-frozen-{}-{}.ckpt'.format(epoch + 1, i + 1)))
# Example 5
                                              batch_size=batch_size,
                                              shuffle=True)

    # Number of points per sample drives the autoencoder's I/O width.
    num_points = dataset.num_points()

    model = Autoencoder(num_points)
    if pretrained_model is not None:
        # Warm-start from a previously saved state dict.
        model.load_state_dict(torch.load(pretrained_model))
    model = model.to(device)
    loss_fn = MSELoss()
    optimizer = torch.optim.Adamax(model.parameters(),
                                   lr=learning_rate,
                                   eps=1e-7)

    for epoch in range(epochs):
        for batch_id, x in enumerate(data_loader):
            x = x.to(device)
            x_hat = model.forward(x)  # NOTE(review): model(x) is the idiomatic call
            criterion = loss_fn(x, x_hat)

            optimizer.zero_grad()
            criterion.backward()
            optimizer.step()

        # NOTE(review): this prints the loss of the LAST batch only, and
        # raises NameError if the loader yielded nothing — confirm intended.
        print("{}: {}".format(epoch, criterion.item()))

        if epoch % 100 == 0:
            # Periodic checkpoint every 100 epochs (includes epoch 0);
            # assumes 'checkpoints/' already exists — TODO confirm.
            torch.save(model.state_dict(),
                       'checkpoints/model-{}.pt'.format(epoch))

    # Final weights after all epochs.
    torch.save(model.state_dict(), 'model.pt')
# Example 6
#     o = output.data.numpy()
#     o = np.reshape(o,(o.shape[0],-1)).tolist()
#     decode_mtx.extend(o)
#     em = embed.data.numpy()
#     em = np.reshape(em,(em.shape[0],-1)).tolist()
#     embed_mtx.extend(em)

# # test testing data
# for data in test_loader2:
#     data = Variable(data).cpu()
#     embed, output = model(data)
#     o = output.data.numpy()
#     o = np.reshape(o,(o.shape[0],-1)).tolist()
#     decode_mtx.extend(o)
#     em = embed.data.numpy()
#     em = np.reshape(em,(em.shape[0],-1)).tolist()
#     embed_mtx.extend(em)

###################################Save Model###################################
# os.system('mkdir ...') spawns a shell and fails (silently, since the
# return code is ignored) when the directory exists or a parent is
# missing; os.makedirs handles both cases portably.
os.makedirs(Info.outpath, exist_ok=True)
torch.save(model.state_dict(), Info.outpath+'model.pth')
# save loss history for later plotting
np.save(Info.outpath+'loss_arr', loss_arr)
# #save decode graph
# decode_mtx = np.transpose(np.array(decode_mtx))
# print('decode_mtx', decode_mtx)
# np.save(Info.outpath+'decode_mtx', decode_mtx)
# #save latent variable
# embed_mtx = np.transpose(np.array(embed_mtx))
# print('embed_mtx', embed_mtx)
# np.save(Info.outpath+'embed_mtx', embed_mtx)
# Example 7
plt.ylabel('loss')
plt.grid()
plt.savefig('./{}/loss.png'.format(log_dir))

# visualize latent space on the MNIST test split
test_dataset = MNIST(data_dir,
                     download=True,
                     train=False,
                     transform=img_transform)
# batch_size=10000 pulls the whole test set in a single batch
test_loader = DataLoader(test_dataset, batch_size=10000, shuffle=False)

# Iterator `.next()` was removed in Python 3 — use the builtin next().
x, labels = next(iter(test_loader))
x = x.view(x.size(0), -1)  # flatten images to vectors for the encoder

if use_gpu:
    x = Variable(x).cuda()
    z = model.encoder(x).cpu().data.numpy()
else:
    x = Variable(x)
    z = model.encoder(x).data.numpy()

# Scatter the first two latent dimensions, coloured by digit label.
plt.figure(figsize=(10, 10))
plt.scatter(z[:, 0], z[:, 1], marker='.', c=labels.numpy(), cmap=plt.cm.jet)
plt.colorbar()
plt.grid()
plt.savefig('./{}/latent_space.png'.format(log_dir))

# save result
np.save('./{}/loss_list.npy'.format(log_dir), np.array(loss_list))
torch.save(model.state_dict(), './{}/model_weights.pth'.format(log_dir))