Example #1
# assumes dalle, optimizerDALLE, tokenDset, cap, fixlen, DALLEloss,
# DATASET_SIZE and EPOCHS have been set up earlier in the script
for epoch in range(EPOCHS):
    for i in range(DATASET_SIZE):
        optimizerDALLE.zero_grad()
        img, strs = cap[i]                      # image tensor and caption strings (strs unused here)
        img = img.unsqueeze(0).cuda()           # add a batch dimension
        if i % 10 == 0:
            print("DALLE step {} / {}".format(i + epoch * DATASET_SIZE,
                                               EPOCHS * DATASET_SIZE))
        try:
            textToken, mask = fixlen([tokenDset.getRand(i)])   # one tokenized caption, padded
        except KeyError:
            continue                            # skip entries whose tokenization raises KeyError
        loss = dalle(textToken.cuda(), img, mask=mask.cuda(), return_loss=True)
        DALLEloss.append(loss.detach().cpu().numpy())
        loss.backward()
        optimizerDALLE.step()

np.savetxt("dalleloss.csv", np.asarray(DALLEloss), delimiter=",")
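
# Optional sanity check on the logged losses; this is only a sketch and assumes
# matplotlib is installed, it is not part of the training run itself.
# import matplotlib.pyplot as plt
# plt.plot(np.loadtxt("dalleloss.csv", delimiter=","))
# plt.xlabel("training step"); plt.ylabel("DALLE loss")
# plt.show()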

# do the above for a long time with a lot of data ... then

torch.save(dalle.state_dict(), "dalle-small.pth")

test_text = "犬が地面に寝そべっている写真"  # "a photo of a dog lying on the ground"

textToken, mask = fixlen([tokenDset.tokenizeList(test_text)])

images = dalle.generate_images(textToken.cuda(), mask=mask.cuda())
print(images.shape)  # (1, 3, 256, 256) for a single caption
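
fixlen is a project-specific helper that is not shown here. As a rough sketch, assuming it simply pads each token list to the model's text_seq_len and returns the padded tokens together with a boolean padding mask, it could look like this (TEXT_SEQ_LEN and pad_id are illustrative assumptions):

import torch

TEXT_SEQ_LEN = 256   # assumed to match the DALLE text_seq_len

def fixlen(token_lists, pad_id=0):
    # pad every token list to TEXT_SEQ_LEN; mask is True for real tokens, False for padding
    tokens = torch.full((len(token_lists), TEXT_SEQ_LEN), pad_id, dtype=torch.long)
    mask = torch.zeros(len(token_lists), TEXT_SEQ_LEN, dtype=torch.bool)
    for row, toks in enumerate(token_lists):
        toks = list(toks)[:TEXT_SEQ_LEN]
        tokens[row, :len(toks)] = torch.tensor(toks, dtype=torch.long)
        mask[row, :len(toks)] = True
    return tokens, mask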
The same model can also be trained on batched data. The fragment below sits inside a per-epoch, per-batch loop in which text and images are one minibatch; device, optimizer, data, batchSize, log_interval and name are defined earlier in that script, train_loss and batch_idx are reset to 0 at the start of each epoch, and save_image comes from torchvision.utils.

        text = text.to(device)       # move the tokenized captions to the GPU
        images = images.to(device)

        mask = torch.ones_like(text).bool().to(device)   # attend to every text position

        # train and optimize a single minibatch
        optimizer.zero_grad()
        loss = dalle(text, images, mask=mask, return_loss=True)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(images), len(data),
                100. * batch_idx / int(round(len(data) / batchSize)),
                loss.item() / len(images)))

        batch_idx += 1

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(data)))

    torch.save(dalle.state_dict(),
               "./models/" + name + "_dalle_" + str(epoch) + ".pth")

    # generate a test sample from the captions in the last minibatch
    oimgs = dalle.generate_images(text, mask=mask)
    save_image(oimgs,
               'results/' + name + '_dalle_epoch_' + str(epoch) + '.png',
               normalize=True)
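
For reference, the dalle object used throughout these snippets comes from the dalle-pytorch package. A minimal construction sketch is shown below; the hyperparameters are illustrative assumptions, not the values actually used in this project, and the vocabulary size would have to match the caption tokenizer.

import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(           # discrete VAE that turns 256x256 images into token grids
    image_size=256,
    num_layers=3,
    num_tokens=8192,
    codebook_dim=512,
    hidden_dim=64,
)

dalle = DALLE(
    dim=1024,
    vae=vae,                 # trained (or pretrained) image tokenizer
    num_text_tokens=10000,   # assumed size of the Japanese caption vocabulary
    text_seq_len=256,
    depth=12,
    heads=16,
).cuda()

# reload a checkpoint saved by the training code above
dalle.load_state_dict(torch.load("dalle-small.pth"))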