示例#1
0
def main(opt, device):
    """Build the dataset and model, then either evaluate once or train.

    When ``opt.test`` is set, the model is evaluated a single time on the
    test loader and the function returns. Otherwise the model is trained
    with SGD and a StepLR schedule for ``opt.max_epoch`` epochs, with
    periodic evaluation and checkpointing.

    Args:
        opt: Options namespace. Attributes read here: ``nlog``, ``test``,
            ``save_dir``, ``data``, ``batch_size``, ``workers``, ``model``,
            ``lr``, ``stepsize``, ``gamma``, ``amp``, ``max_epoch`` and
            ``eval_freq``. Attributes written here: ``num_classes`` (always)
            and ``scaler`` (only when ``opt.amp`` is truthy).
        device: Device the model is moved to; ``device.type`` decides
            whether CUDA-specific paths are taken.

    Returns:
        None.
    """
    # Replace stdout with the project's Logger unless logging is disabled or
    # this is a test-only run. NOTE(review): stdout is never restored here.
    if not opt.nlog and not opt.test:
        sys.stdout = Logger(Path(opt.save_dir) / 'log_.txt')
    print_argument_options(opt)

    # Configure
    cuda = device.type != 'cpu'
    init_torch_seeds()

    dataset = load_datasets(opt.data, opt.batch_size, cuda, opt.workers)
    trainloader, testloader = dataset.trainloader, dataset.testloader
    opt.num_classes = dataset.num_classes
    print("Creat dataset: {}".format(opt.data))

    model = build_models(opt.model, opt.num_classes).to(device)
    print(model)
    # Wrap in DataParallel only when more than one CUDA device is visible.
    if cuda and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    print("Creat model: {}".format(opt.model))

    if opt.test:
        # Test-only mode: one pass over the test loader, then exit.
        acc, err = __testing(opt, model, testloader, 0, device)
        # This evaluation uses testloader, so it is reported as test accuracy.
        print("==> Test Accuracy (%): {}\t Error rate(%): {}".format(
            acc, err))
        return

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=opt.lr,
                                weight_decay=5e-04,
                                momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=opt.stepsize,
                                                gamma=opt.gamma)

    # The gradient scaler is attached to opt so it travels with the options
    # object passed to __training.
    if opt.amp:
        opt.scaler = torch.cuda.amp.GradScaler(enabled=True)

    start_time = time.time()
    for epoch in range(opt.max_epoch):
        print("==> Epoch {}/{}".format(epoch + 1, opt.max_epoch))
        __training(opt, model, criterion, optimizer, trainloader, epoch,
                   device)
        scheduler.step()

        # Evaluate every eval_freq epochs; the `> 0` guard short-circuits so
        # eval_freq == 0 never reaches the modulo. The final epoch is always
        # evaluated, even when periodic evaluation is disabled.
        is_periodic_eval = (opt.eval_freq > 0
                            and (epoch + 1) % opt.eval_freq == 0)
        is_last_epoch = (epoch + 1) == opt.max_epoch
        if is_periodic_eval or is_last_epoch:
            acc, err = __testing(opt, model, trainloader, epoch, device)
            print("==> Train Accuracy (%): {}\t Error rate(%): {}".format(
                acc, err))
            acc, err = __testing(opt, model, testloader, epoch, device)
            print("==> Test Accuracy (%): {}\t Error rate(%): {}".format(
                acc, err))
            save_model(model, epoch, name=opt.model, save_dir=opt.save_dir)

    # Report wall-clock training time, rounded to whole seconds.
    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))
示例#2
0
            d_loss.backward()
            d_optim.step()
            d_log.append(d_loss.item())

            # Train G

            z = torch.randn(batch_size, latent_dim).to(device)
            fake_img = G(z, onehot_class)

            fake_score = D(fake_img, onehot_class)

            g_loss = criterion(fake_score, real_labels)

            g_optim.zero_grad()
            g_loss.backward()
            g_optim.step()
            g_log.append(g_loss.item())

            utils.show_process(epoch_i, step_i + 1, step_per_epoch, g_log,
                               d_log)

        if epoch_i == 1:
            torchvision.utils.save_image(real_img.reshape(-1, 1, 32, 32),
                                         os.path.join(sample_dir, 'real.png'),
                                         nrow=10)

        utils.save_model(G, g_optim, g_log, checkpoint_dir, 'G.ckpt')
        utils.save_model(D, d_optim, d_log, checkpoint_dir, 'D.ckpt')
        cgan_utils.generate_classes(G, latent_dim, device, 10, epoch_i,
                                    sample_dir)
示例#3
0
                result = real_img

            code = net.encoder(real_img)

            reg_loss = 0
            if reg:
                reg_loss = ae_utils.L2_reg(code)

            reconstructed = net.decoder(code)

            reconstruction_loss = criterion(reconstructed, real_img)
            loss = reg_loss + reconstruction_loss

            optim.zero_grad()
            loss.backward()
            optim.step()
            loss_log.append(loss.item())

            ae_utils.show_process(epoch_i, step_i + 1, step_per_epoch,
                                  loss_log)

        if epoch_i == 1:
            torchvision.utils.save_image(result,
                                         os.path.join(sample_dir, 'orig.png'),
                                         nrow=10)
        reconstructed = net(result)
        utils.save_image(reconstructed, 10, epoch_i, step_i + 1, sample_dir)

        utils.save_model(net, optim, loss_log, checkpoint_dir,
                         'autoencoder.ckpt')
示例#4
0
            d_optim.zero_grad()
            loss.backward()
            d_optim.step()
            d_log.append(loss.item())

            code = AE.encoder(img)
            fake_score = D(code)
            loss = discrim_criterion(fake_score, real_label)

            ae_optim.zero_grad()
            loss.backward()
            ae_optim.step()

            utils.show_process(epoch_i, step_i + 1, step_per_epoch, rec_log,
                               d_log)

        if epoch_i == 1:
            torchvision.utils.save_image(result.reshape(-1, 1, 28, 28),
                                         os.path.join(sample_dir, 'orig.png'),
                                         nrow=10)
        reconstructed = AE(result)
        utils.save_image(reconstructed.reshape(-1, 1, 28, 28), 10, epoch_i,
                         step_i + 1, sample_dir)

        utils.save_model(AE, ae_optim, rec_log, checkpoint_dir, 'AE.ckpt')
        utils.save_model(D, d_optim, d_log, checkpoint_dir, 'D.ckpt')

    ae_utils.plot_manifold(AE.encoder, device, dataloader.dataset,
                           dataloader.dataset.__len__(), sample_dir)