Пример #1
0
# torchvision.transforms.ToTensor()
# ]))

# Training loader shuffles each epoch for SGD; eval loader keeps fixed order.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=BATCH_SIZE,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=BATCH_SIZE,
                                          shuffle=False)

# Adam optimizer; cross-entropy over 256 intensity classes per sub-pixel
# (see the integer target construction below).
optim = torch.optim.Adam(model.parameters(), lr=0.001)
crit = nn.CrossEntropyLoss()

# Timestamp tag for checkpoints; track best loss / best weights seen so far.
save_id = int(time.time())
best_loss = 999.
best_model = copy.deepcopy(model.state_dict())

for ei in range(NB_EPOCHS):
    print(f"\n> Epoch {ei+1}/{NB_EPOCHS}")
    train_loss = 0.0
    eval_loss = 0.0

    model.train()
    for x, h in tqdm(train_loader):
        optim.zero_grad()
        x, h = x.to(device), h.to(device)
        # Scale inputs to integer class labels 0..255 for per-pixel
        # classification. NOTE(review): assumes x is a float image in
        # [0, 1] (e.g. from ToTensor()) — confirm upstream.
        target = (x * 255).long()

        pred = model(x, h)
        # Logits reshaped to (batch, 256 classes, pixels) as CrossEntropyLoss
        # expects. NOTE(review): view(BATCH_SIZE, ...) breaks on a final
        # partial batch; snippet also appears truncated here — no
        # loss.backward()/optim.step() is visible.
        loss = crit(pred.view(BATCH_SIZE, 256, -1),
                    target.view(BATCH_SIZE, -1))
Пример #2
0
def train(config, mode='cifar10'):
    """Train a PixelCNN++ on the chosen dataset, sampling and checkpointing
    every ``config.save_interval`` epochs.

    Args:
        config: namespace providing lr, nr_resnet, nr_filters,
            nr_logistic_mix, batch_size, lr_decay, load_params,
            max_epochs and save_interval.
        mode: 'cifar10' or 'faces' (3x32x32 inputs) or 'mnist' (1x28x28).

    Raises:
        ValueError: if ``mode`` is not one of the supported datasets.
    """
    model_name = 'pcnn_lr:{:.5f}_nr-resnet{}_nr-filters{}'.format(
        config.lr, config.nr_resnet, config.nr_filters)
    # exist_ok tolerates pre-existing dirs without swallowing unrelated
    # OSErrors (permissions, read-only fs) the old try/except did.
    os.makedirs('models', exist_ok=True)
    os.makedirs('images', exist_ok=True)

    # Draw a fresh seed, print it for reproducibility, then seed everything.
    seed = np.random.randint(0, 10000)
    print("Random Seed: ", seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.benchmark = True

    trainset, train_loader, testset, test_loader, classes = load_data(
        mode=mode, batch_size=config.batch_size)
    if mode in ('cifar10', 'faces'):
        obs = (3, 32, 32)
        loss_op = lambda real, fake: discretized_mix_logistic_loss(
            real, fake, config.nr_logistic_mix)
        sample_op = lambda x: sample_from_discretized_mix_logistic(
            x, config.nr_logistic_mix)
    elif mode == 'mnist':
        obs = (1, 28, 28)
        loss_op = lambda real, fake: discretized_mix_logistic_loss_1d(
            real, fake, config.nr_logistic_mix)
        sample_op = lambda x: sample_from_discretized_mix_logistic_1d(
            x, config.nr_logistic_mix)
    else:
        # Previously an unknown mode silently left loss_op/sample_op
        # undefined and crashed later with a NameError; fail fast instead.
        raise ValueError('unsupported mode: {!r}'.format(mode))
    sample_batch_size = 25
    # Undo the [-1, 1] rescaling applied by the data pipeline before saving.
    rescaling_inv = lambda x: .5 * x + .5

    model = PixelCNN(nr_resnet=config.nr_resnet, nr_filters=config.nr_filters,
                     input_channels=obs[0], nr_logistic_mix=config.nr_logistic_mix).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=config.lr_decay)

    if config.load_params:
        load_part_of_model(model, config.load_params)
        print('model parameters loaded')

    def sample(model):
        """Autoregressively generate sample_batch_size images pixel by pixel."""
        model.eval()
        data = torch.zeros(sample_batch_size, obs[0], obs[1], obs[2])
        data = data.cuda()
        with tqdm(total=obs[1] * obs[2]) as pbar:
            for i in range(obs[1]):
                for j in range(obs[2]):
                    with torch.no_grad():
                        out = model(data, sample=True)
                        out_sample = sample_op(out)
                        # Only pixel (i, j) is valid at this step; keep the
                        # rest of the canvas untouched.
                        data[:, :, i, j] = out_sample.data[:, :, i, j]
                    pbar.update(1)
        return data

    print('starting training')
    for epoch in range(config.max_epochs):
        model.train()
        torch.cuda.synchronize()
        train_loss = 0.
        time_ = time.time()
        with tqdm(total=len(train_loader)) as pbar:
            for batch_idx, (data, label) in enumerate(train_loader):
                # Inputs need no gradients of their own; the old
                # requires_grad_(True) only wasted memory.
                data = data.cuda()

                output = model(data)
                loss = loss_op(data, output)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
                pbar.update(1)

        # (batch_idx + 1) batches were processed; the old zero-based
        # batch_idx under-counted by one batch in the denominator.
        deno = (batch_idx + 1) * config.batch_size * np.prod(obs)
        print('train loss : %s' % (train_loss / deno), end='\t')

        # decrease learning rate
        scheduler.step()

        model.eval()
        test_loss = 0.
        # no_grad skips graph construction during evaluation.
        with torch.no_grad(), tqdm(total=len(test_loader)) as pbar:
            for batch_idx, (data, _) in enumerate(test_loader):
                data = data.cuda()

                output = model(data)
                loss = loss_op(data, output)
                test_loss += loss.item()
                del loss, output
                pbar.update(1)
        deno = (batch_idx + 1) * config.batch_size * np.prod(obs)
        print('test loss : {:.4f}, time : {:.4f}'.format((test_loss / deno), (time.time() - time_)))

        torch.cuda.synchronize()

        if (epoch + 1) % config.save_interval == 0:
            torch.save(model.state_dict(), 'models/{}_{}.pth'.format(model_name, epoch))
            print('sampling...')
            sample_t = sample(model)
            sample_t = rescaling_inv(sample_t)
            save_image(sample_t, 'images/{}_{}.png'.format(model_name, epoch), nrow=5, padding=0)
Пример #3
0
def main():
    """Train a PixelCNN on quantised CIFAR-10, checkpointing every epoch."""
    path = 'data'
    data_name = 'CIFAR'
    batch_size = 64

    layers = 10
    kernel = 7
    channels = 128
    epochs = 25
    save_path = 'models'

    # Map 8-bit pixel values into [0, 1] before quantisation.
    normalize = transforms.Lambda(lambda image: np.array(image) / 255.0)

    def quantisize(image, levels):
        # Bucket values in [0, 1) into integer bins 0..levels-1.
        return np.digitize(image, np.arange(levels) / levels) - 1

    discretize = transforms.Compose([
        transforms.Lambda(lambda image: quantisize(image, (channels - 1))),
        transforms.ToTensor()
    ])
    cifar_transform = transforms.Compose([normalize, discretize])

    train = datasets.CIFAR10(root=path, train=True, download=True, transform=cifar_transform)
    test = datasets.CIFAR10(root=path, train=False, download=True, transform=cifar_transform)

    train = data.DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
    test = data.DataLoader(test, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    net = PixelCNN(num_layers=layers, kernel_size=kernel, num_channels=channels).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters())
    loss_overall = []
    for i in range(epochs):
        # Periodically dump samples from the partially-trained model.
        if (i % 3) == 1:
            sampling(net, i - 1, channels)
        net.train(True)
        step = 0
        loss_ = 0.0
        for images, labels in tqdm(train, desc='Epoch {}/{}'.format(i + 1, epochs)):
            images = images.to(device)
            # Model input is the image rescaled to [0, 1]; the quantised
            # integer image itself is the classification target.
            normalized_images = images.float() / ((channels - 1))
            optimizer.zero_grad()

            output = net(normalized_images)
            loss = criterion(output, images)
            loss.backward()
            optimizer.step()

            # .item() detaches the scalar; accumulating the live tensor
            # kept every batch's autograd graph alive all epoch (leak).
            loss_ += loss.item()
            step += 1

        print('Epoch:' + str(i) + '       , ' + 'Average loss: ', loss_ / step)
        with open("hst.txt", "a") as myfile:
            myfile.write(str(loss_ / step) + '\n')
        # exist_ok replaces the racy exists()-then-makedirs check.
        os.makedirs(save_path, exist_ok=True)
        if i == epochs - 1:
            torch.save(net.state_dict(), save_path + '/model_' + 'Last' + '.pt')
        else:
            torch.save(net.state_dict(), save_path + '/model_' + str(i) + '.pt')
        print('model saved')
    # NOTE(review): this fragment is spliced from a TF->PyTorch weight
    # conversion script; the `if` matching the `else:` below (and the
    # definitions of hdf5_path, tf_path, init_vars) are not visible here,
    # so the snippet is not runnable as-is.
    h5f = h5py.File(hdf5_path, 'r')
else:
    # First run: dump each TF variable whose name contains 'model' into an
    # HDF5 file, converting '/'-separated TF names to '.'-separated keys.
    with h5py.File(hdf5_path, 'w') as h5f:
        for name, shape in init_vars:

            val = tf.train.load_variable(tf_path, name)

            print(val.dtype)
            print("Loading TF weight {} with shape {}, {}".format(
                name, shape, val.shape))
            # NOTE(review): this tensor is created and discarded — the
            # result is never used.
            torch.from_numpy(np.array(val))
            if 'model' in name:
                new_name = name.replace('/', '.')
                print(new_name)
                h5f.create_dataset(str(new_name), data=val)

    # Reopen the freshly written dump read-only for the conversion below.
    h5f = h5py.File(hdf5_path, 'r')

# Instantiate a PixelCNN matching the architecture the TF checkpoint was
# trained with (5 resnet blocks, 160 filters, RGB input, 10-component
# logistic mixture). NOTE(review): these hyperparameters must match the
# checkpoint exactly or load_state_dict will fail — confirm against it.
model = PixelCNN(nr_resnet=5,
                 nr_filters=160,
                 input_channels=3,
                 nr_logistic_mix=10)

#print(model.state_dict().keys())
# Translate the HDF5-dumped TF weights into a PyTorch state dict.
converter = TF2Pytorch(h5f)
converter.load_pixelcnn()

# Load the converted weights, persist them as a PyTorch checkpoint,
# and release the HDF5 handle.
model.load_state_dict(converter.state_dict)
torch.save(model.state_dict(), ckpt_path)
h5f.close()