Example 1
def train(loader, model, criterion, optimizer, epoch):
    running_loss = 0.0
    running_acc = 0.0
    total = 0
    print('epoch {0}'.format(epoch + 1))
    print('*' * 10)
    for idx, (inputs, targets) in enumerate(loader, 1):
        inputs = inputs.view(inputs.size(0), -1)
        # Variable is a no-op wrapper in PyTorch >= 0.4; kept from the
        # original snippet
        if args.cuda:
            inputs = Variable(inputs).cuda()
            targets = Variable(targets).cuda()
        else:
            inputs = Variable(inputs)
            targets = Variable(targets)
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # loss.item() replaces the pre-0.4 loss.data[0] indexing, which
        # raises an error on the 0-dim loss tensors of modern PyTorch
        running_loss += loss.item() * targets.size(0)
        _, pred = torch.max(outputs, 1)
        num_correct = (pred == targets).sum()
        running_acc += num_correct.item()
        total += targets.size(0)
        optimizer = exp_lr_scheduler(optimizer=optimizer,
                                     epoch=epoch + 1,
                                     init_lr=args.init_lr)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if idx % 500 == 0:
            print('{}/{} Loss: {:.6f}, Acc: {:.6f}'.format(
                epoch + 1, args.epochs, running_loss / total,
                running_acc / total))

    print('Finish {} epochs, Loss:{:.6f}, Acc:{:.6f}'.format(
        epoch + 1, running_loss / total, running_acc / total))
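
The exp_lr_scheduler helper called above is not shown in any of the examples. Below is a minimal sketch matching the call signature used in this classifier example (step decay; the decay interval and factor are assumptions, not values recovered from the original code). Note that the GAN examples further down call a different, iteration-based variant from their utils module.

def exp_lr_scheduler(optimizer, epoch, init_lr, lr_decay_epoch=10, gamma=0.1):
    # Multiply the initial learning rate by gamma every lr_decay_epoch
    # epochs; both defaults are assumed, not taken from the original code.
    lr = init_lr * (gamma ** ((epoch - 1) // lr_decay_epoch))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer
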
Example 2
def train(G, D, dataloader, mode="vanilla"):
    vis = visdom.Visdom(base_url=config.Visdom_Base_Url)
    cuda0 = torch.device('cuda:0')
    cpu0 = torch.device('cpu')

    # OPTIMIZERS: WGAN-GP uses Adam with betas=(0., 0.9) as in the paper,
    # while vanilla WGAN (weight clipping) uses RMSprop
    if mode != 'vanilla':
        G_optimizer = optim.Adam(G.parameters(), lr=0.0001, betas=(0., 0.9))
        D_optimizer = optim.Adam(D.parameters(), lr=0.0001, betas=(0., 0.9))
    else:
        G_optimizer = optim.RMSprop(G.parameters(), lr=0.0001)
        D_optimizer = optim.RMSprop(D.parameters(), lr=0.0001)

    logging.info("WGAN training start (mode={0})".format(mode))

    z_fixed = torch.randn((16, 100)).view(-1, 100, 1, 1).to(cuda0)

    # the critic is updated d_iter times for every generator update,
    # as recommended in the WGAN papers
    d_iter = 5
    g_iter = 1
    g_iter_count = 0

    num_iter = 0
    for epoch in range(10000):
        for batch in dataloader:
            batch = batch[0]
            mini_batch = batch.size()[0]

            batch = batch.to(cuda0)

            utils.exp_lr_scheduler(G_optimizer, num_iter, lr=0.0001)
            utils.exp_lr_scheduler(D_optimizer, num_iter, lr=0.0001)

            # Discriminator step
            #real_d_iter = 100 if g_iter_count < 20 or g_iter_count % 500 == 0 else d_iter
            for i in range(d_iter):
                X = batch.view(-1, 1, 32, 32)

                G_result = G(
                    torch.randn((mini_batch, 100)).view(-1, 100, 1,
                                                        1).to(cuda0))

                D.zero_grad()
                # Wasserstein critic objective: maximise D(real) - D(fake),
                # i.e. minimise D(fake) - D(real)
                D_real_loss = D(X).mean()
                D_fake_loss = D(G_result).mean()
                D_train_loss = D_fake_loss - D_real_loss

                if mode == 'gp':
                    D_train_loss += calc_gradient_penalty(
                        D, X, G_result, mini_batch)

                D_train_loss.backward()
                D_optimizer.step()

                if mode == 'vanilla':
                    # weight clipping enforces the Lipschitz constraint
                    # in the original WGAN
                    for p in D.parameters():
                        p.data.clamp_(-0.01, 0.01)

                if num_iter % 10 == 0 and i == d_iter - 1:
                    visdom_wrapper.line_pass(vis,
                                             num_iter,
                                             -D_train_loss.item(),
                                             'rgb(235, 99, 99)',
                                             "Train",
                                             name='d_loss',
                                             title='Loss')

            # Generator step
            for i in range(g_iter):
                G_result = G(
                    torch.randn((mini_batch, 100)).view(-1, 100, 1,
                                                        1).to(cuda0))
                G.zero_grad()
                G_train_loss = -D(G_result).mean()
                G_train_loss.backward()
                G_optimizer.step()

                if num_iter % 10 == 0:
                    visdom_wrapper.line_pass(vis,
                                             num_iter,
                                             G_train_loss.item(),
                                             'rgb(99, 153, 235)',
                                             "Train",
                                             name='g_loss',
                                             title='Loss')

                g_iter_count += 1

            # Other tasks
            num_iter += 1
            if num_iter % 100 == 0:
                torch.save(G, config.Models_Path + "/wgan_mnist_g")
                torch.save(D, config.Models_Path + "/wgan_mnist_d")

            if num_iter % 10 == 0:
                # detach so the preview pass does not keep an autograd graph
                vis.images(G(z_fixed).detach().to(cpu0) * 0.5 + 0.5,
                           nrow=4,
                           opts=dict(title='Generator updates'),
                           win="Generator_out")

        logging.info("WGAN-GP epoch {0} end".format(epoch))
Example 3
def train(model, criterion, optimizer, epoch):
    # NOTE: the original listing of this example was truncated; the header
    # and the opening lines down to the use_gpu branch are reconstructed to
    # mirror Example 1, so the signature and the train_loader / use_gpu
    # names are assumptions.
    running_loss = 0.0
    running_acc = 0.0
    for i, (img, label) in enumerate(train_loader, 1):
        img = img.view(img.size(0), -1)
        if use_gpu:
            img = Variable(img).cuda()
            label = Variable(label).cuda()
        else:
            img = Variable(img)
            label = Variable(label)
        _, out = model(img)
        loss = criterion(out, label)

        # running loss over the i * batch_size images seen so far
        # (loss.item() replaces the pre-0.4 loss.data[0] indexing)
        running_loss += loss.item() * label.size(0)
        _, pred = torch.max(out, 1)
        num_correct = (pred == label).sum()
        running_acc += num_correct.item()

        optimizer = exp_lr_scheduler(optimizer=optimizer,
                                     epoch=epoch + 1,
                                     init_lr=init_lr)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print('{}/{} Loss: {:.6f}, Acc: {:.6f}'.format(
                epoch + 1, epochs, running_loss / (batch_size * i),
                running_acc / (batch_size * i)))

    print('Finish {} epochs, Loss:{:.6f}, Acc:{:.6f}'.format(
        epoch + 1, running_loss / len(train_mnist),
        running_acc / len(train_mnist)))

    model.eval()
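
The function switches to eval mode on exit, which suggests a validation pass follows each epoch. A hypothetical companion loop under the same assumptions as the reconstruction above (test_loader and the CPU-only setup are assumptions):

def evaluate(model, criterion):
    model.eval()
    eval_loss = 0.0
    eval_acc = 0.0
    total = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for img, label in test_loader:
            img = img.view(img.size(0), -1)
            _, out = model(img)
            eval_loss += criterion(out, label).item() * label.size(0)
            _, pred = torch.max(out, 1)
            eval_acc += (pred == label).sum().item()
            total += label.size(0)
    print('Eval Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / total,
                                                  eval_acc / total))
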
Example 4
def train(G, D, dataloader):
    vis = visdom.Visdom(base_url=config.Visdom_Base_Url)
    cuda0 = torch.device('cuda:0')
    cpu0 = torch.device('cpu')

    # LOSS + OPTIMIZER
    BCE_loss = nn.BCELoss()
    G_optimizer = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.9))
    D_optimizer = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.9))

    logging.info("DCGAN training start")

    z_fixed = torch.randn((16, 100)).view(-1, 100, 1, 1).to(cuda0)

    num_iter = 0
    for epoch in range(10000):
        for batch in dataloader:
            batch = batch[0]
            mini_batch = batch.size()[0]
            batch = batch.to(cuda0)

            y_real = torch.ones(mini_batch, device=cuda0)
            y_fake = torch.zeros(mini_batch, device=cuda0)

            utils.exp_lr_scheduler(G_optimizer, num_iter, lr=0.0002)
            utils.exp_lr_scheduler(D_optimizer, num_iter, lr=0.0002)

            # Discriminator step
            X = batch.view(-1, 1, 32, 32)
            G_result = G(
                torch.randn((mini_batch, 100)).view(-1, 100, 1, 1).to(cuda0))
            D.zero_grad()
            D_real_loss = BCE_loss(D(X).squeeze(), y_real)
            D_fake_loss = BCE_loss(D(G_result).squeeze(), y_fake)
            D_train_loss = D_real_loss + D_fake_loss
            D_train_loss.backward()
            D_optimizer.step()

            if num_iter % 10 == 0:
                visdom_wrapper.line_pass(vis,
                                         num_iter,
                                         D_train_loss.item(),
                                         'rgb(235, 99, 99)',
                                         "Train",
                                         name='d_loss',
                                         title='Loss')

            # Generator step: label fakes as real (the non-saturating GAN loss)
            G_result = G(
                torch.randn((mini_batch, 100)).view(-1, 100, 1, 1).to(cuda0))
            G.zero_grad()
            G_train_loss = BCE_loss(D(G_result).squeeze(), y_real)
            G_train_loss.backward()
            G_optimizer.step()

            if num_iter % 10 == 0:
                visdom_wrapper.line_pass(vis,
                                         num_iter,
                                         G_train_loss.item(),
                                         'rgb(99, 153, 235)',
                                         "Train",
                                         name='g_loss',
                                         title='Loss')

            # Other tasks
            num_iter += 1
            if num_iter % 100 == 0:
                torch.save(G, config.Models_Path + "/dcgan_mnist_g")
                torch.save(D, config.Models_Path + "/dcgan_mnist_d")

            if num_iter % 10 == 0:
                # detach so the preview pass does not keep an autograd graph
                vis.images(G(z_fixed).detach().to(cpu0) * 0.5 + 0.5,
                           nrow=4,
                           opts=dict(title='Generator updates'),
                           win="Generator_out")

        logging.info("DCGAN epoch {0} end".format(epoch))