Example #1
0
class Extractor(object):
    """Extract image features with a (CUDA) model.

    Two extraction paths are supported: walking a class-per-directory image
    tree and transforming images one by one, or batching through an
    ``ImageDataLoader``.

    Parameters
    ----------
    e_model : feature-extraction network (assumed to live on the GPU —
        inputs are moved with ``.cuda()``).
    batch_size : batch size used on the dataloader path.
    cat_info : if True the model output is a tuple whose second element is
        the feature; otherwise the raw output is the feature.
    vis : if True, show the images being processed in a Visualizer window.
    dataloader : truthy selects the ImageDataLoader path; falsy selects the
        per-image directory-walk path.
    """

    def __init__(self,
                 e_model,
                 batch_size=128,
                 cat_info=True,
                 vis=False,
                 dataloader=False):
        self.batch_size = batch_size
        self.cat_info = cat_info

        self.model = e_model

        # BUG FIX: always record the flag. Previously the attribute was only
        # set on the truthy branch, so `extract` raised AttributeError
        # whenever the extractor was built without a dataloader (the default).
        self.dataloader = dataloader
        if not dataloader:
            # per-image preprocessing for the directory-walk path
            self.transform = tv.transforms.Compose([
                tv.transforms.Resize(224),
                tv.transforms.ToTensor(),
                tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
        self.vis = vis
        if self.vis:
            self.viser = Visualizer('caffe2torch_test')

    @staticmethod
    def _save(data, out_root):
        """Pickle `data` to `out_root` when a path is given (file is closed)."""
        if out_root:
            with open(out_root, 'wb') as out:
                pickle.dump(data, out)

    # extract the inputs' feature via self.model
    @t.no_grad()
    def extract(self, data_root, out_root=None):
        """Extract features for every image under `data_root`.

        Returns a dict ``{'name': [...], 'feature': ...}`` and, when
        `out_root` is given, pickles it there as well.
        """
        if self.dataloader:
            return self._extract_with_dataloader(data_root=data_root,
                                                 cat_info=self.cat_info,
                                                 out_root=out_root)
        else:
            return self._extract_without_dataloader(data_root=data_root,
                                                    cat_info=self.cat_info,
                                                    out_root=out_root)

    # the model's output contains both the inputs' feature and category info
    @t.no_grad()
    def _extract_without_dataloader(self, data_root, cat_info, out_root):
        """Walk `data_root` (one sub-directory per class) image by image."""
        feature = []
        name = []

        self.model.eval()

        for cname in sorted(os.listdir(data_root)):
            c_path = os.path.join(data_root, cname)
            if not os.path.isdir(c_path):
                continue
            for fname in sorted(os.listdir(c_path)):
                path = os.path.join(c_path, fname)

                # NOTE(review): assumes RGB files; grayscale/CMYK inputs
                # would need Image.open(path).convert('RGB') — confirm.
                image = Image.open(path)
                image = self.transform(image)[None].cuda()  # add batch dim

                if self.vis:
                    self.viser.images(image.cpu().numpy() * 0.5 + 0.5,
                                      win='extractor')
                out = self.model(image)
                # with category info the model returns (category, feature)
                i_feature = out[1] if cat_info else out

                feature.append(i_feature.cpu().squeeze().numpy())
                name.append(cname + '/' + fname)

        data = {'name': name, 'feature': feature}
        self._save(data, out_root)
        return data

    # the input is loaded in batches by a dataloader
    @t.no_grad()
    def _extract_with_dataloader(self, data_root, cat_info, out_root):
        """Extract features in batches via ImageDataLoader."""
        names = []
        features = []

        self.model.eval()

        opt = Config()
        opt.image_root = data_root
        # FIX: honour the configured batch size instead of a hard-coded 128
        opt.batch_size = self.batch_size

        dataset = ImageDataLoader(opt).load_data()

        for data in dataset:
            image = data['I'].cuda()

            out = self.model(image)
            # with category info the model returns (category, feature)
            i_feature = out[1] if cat_info else out

            # collect batches and concatenate once at the end
            # (avoids the O(n^2) repeated np.append of the old code)
            features.append(i_feature.cpu().squeeze().numpy())
            names += data['N']

        feature = np.concatenate(features, axis=0) if features else np.empty(0)

        data = {'name': names, 'feature': feature}
        self._save(data, out_root)
        return data

    # reload model with model file
    def reload_state_dict_with_fc(self, state_file):
        """Load weights from a checkpoint that includes a fc layer.

        The checkpoint is loaded into a temporary resnet34 (512 -> 125 fc),
        then only the keys that also exist in ``self.model`` are copied over.
        """
        temp_model = tv.models.resnet34(pretrained=False)
        temp_model.fc = nn.Linear(512, 125)
        temp_model.load_state_dict(t.load(state_file))

        pretrained_dict = temp_model.state_dict()
        model_dict = self.model.state_dict()

        # keep only the weights self.model actually has
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }

        model_dict.update(pretrained_dict)
        self.model.load_state_dict(model_dict)

    # reload model with model file
    def reload_state_dic(self, state_file):
        """Load a checkpoint whose keys match ``self.model`` exactly."""
        self.model.load_state_dict(t.load(state_file))

    # reload model with model object directly
    def reload_model(self, model):
        """Replace the wrapped model object directly."""
        self.model = model
Example #2
0
File: main.py Project: lucineIT/GAN
def train(**kwargs):
    """Train a DCGAN: alternate discriminator and generator updates.

    Any keyword argument overrides the matching attribute on the global
    ``opt`` config.  Uses the legacy (pre-0.4) PyTorch ``Variable`` /
    ``.data[0]`` API throughout.
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    if opt.vis:
        from utils.visualize import Visualizer
        vis = Visualizer(opt.env)

    # preprocessing: resize + centre-crop to image_size, normalise to [-1, 1]
    transforms = tv.transforms.Compose([
        tv.transforms.Scale(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True)

    # build the generator and discriminator networks
    netg, netd = NetGenerator(opt), NetD(opt)
    map_location = lambda storage, loc: storage  # load checkpoints onto CPU
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))

    # optimizers and loss
    optimizer_g = t.optim.Adam(netg.parameters(),
                               opt.G_lr,
                               betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(),
                               opt.D_lr,
                               betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss()

    # real images are labelled 1, fake images 0;
    # `noises` is the generator input, `fix_noises` a fixed batch for visualisation
    true_labels = Variable(t.ones(opt.batch_size))
    fake_labels = Variable(t.zeros(opt.batch_size))
    fix_noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))
    noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    if opt.use_gpu:
        netd.cuda()
        netg.cuda()
        criterion.cuda()
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises, noises = fix_noises.cuda(), noises.cuda()

    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = Variable(img)
            if opt.use_gpu:
                real_img = real_img.cuda()
            if ii % opt.d_every == 0:
                # train the discriminator
                optimizer_d.zero_grad()
                ## push it to classify real images as real
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                ## and fake images as fake
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()  # generate fakes from noise
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.data[0])

            if ii % opt.g_every == 0:
                # train the generator: make the discriminator call fakes real
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.data[0])

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                ## visualisation (drops into ipdb when the debug file exists)
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='fixfake')
                vis.plot('error_d', errord_meter.value()[0])
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='real')
                vis.plot('error_g', errorg_meter.value()[0])

        if epoch % opt.decay_every == 0:
            # save models and sample images
            # NOTE(review): `fix_fake_imgs` is only assigned inside the
            # `opt.vis` branch above — this raises NameError when vis is off
            # or plot_every was never reached; confirm intended.
            tv.utils.save_image(fix_fake_imgs.data[:64],
                                '%s/%s.png' % (opt.save_path,
                                               (epoch + opt.startpoint)),
                                normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(),
                   'checkpoints/netd_%s.pth' % (epoch + opt.startpoint))
            t.save(netg.state_dict(),
                   'checkpoints/netg_%s.pth' % (epoch + opt.startpoint))
            errord_meter.reset()
            errorg_meter.reset()
            # NOTE(review): the optimizers are rebuilt with the *same*
            # opt.G_lr / opt.D_lr — this resets Adam state but does not
            # actually decay the learning rate; confirm intended.
            optimizer_g = t.optim.Adam(netg.parameters(),
                                       opt.G_lr,
                                       betas=(opt.beta1, 0.999))
            optimizer_d = t.optim.Adam(netd.parameters(),
                                       opt.D_lr,
                                       betas=(opt.beta1, 0.999))
def train(**kwargs):
    """Train a text classifier, validating and plotting after every epoch.

    Keyword arguments override fields of the global ``opt`` config via
    ``opt._parse``.  The learning rate is decayed whenever the mean epoch
    loss stops improving.
    """
    opt._parse(kwargs)
    vis = Visualizer(opt.env, port=opt.vis_port)

    # step1: configure model (looked up by name on the `models` module)
    model = getattr(models, opt.model)()
    if opt.load_model_path:
        model.load_new(opt.load_model_path)
    else:
        print('Initialize the model!')
        model.apply(weight_init)

    model.to(opt.device)

    # step2: data
    train_data = TextData(opt.data_root, opt.train_txt_path)
    val_data = TextData(opt.data_root, opt.val_txt_path)
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers)

    # step3: criterion and optimizer (optimizer supplied by the model itself)
    criterion = t.nn.CrossEntropyLoss()
    lr = opt.lr
    optimizer = model.get_optimizer(lr, opt.weight_decay)

    # step4: meters (2-class confusion matrix)
    loss_meter = meter.AverageValueMeter()
    confusion_matrix = meter.ConfusionMeter(2)
    previous_loss = 1e10

    # train
    for epoch in range(opt.max_epoch):

        loss_meter.reset()
        confusion_matrix.reset()

        for ii, (data, label) in tqdm(enumerate(train_dataloader)):
            # train model
            input = data.to(opt.device)
            target = label.to(opt.device)
            optimizer.zero_grad()
            score = model(input)
            loss = criterion(score, target)
            loss.backward()
            #for n, p in model.named_parameters():
            #    print(n)
            #    h = p.register_hook(lambda grad: print(grad))
            optimizer.step()

            # meters update and visualize
            loss_meter.add(loss.item())
            confusion_matrix.add(score.data, target.data)
            if ii % opt.print_freq == 0:
                vis.plot('loss', loss_meter.value()[0])

                # enter debug mode when the debug file exists
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
            if ii % (opt.print_freq * 10) == 0:
                vis.images(input.cpu().numpy(),
                           opts=dict(title='Label', caption='Label'),
                           win=1)
                print('Epoch: {} Iter: {} Loss: {}'.format(epoch, ii, loss))

        # checkpoint every other epoch
        if epoch % 2 == 0:
            model.save('./checkpoints/' + opt.env + '_' + str(epoch) + '.pth')

        # validate and visualize
        val_cm, val_accuracy = val(model, val_dataloader)

        vis.plot('val_accuracy', val_accuracy)
        vis.log(
            "epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}"
            .format(epoch=epoch,
                    loss=loss_meter.value()[0],
                    val_cm=str(val_cm.value()),
                    train_cm=str(confusion_matrix.value()),
                    lr=lr))
        # training accuracy = trace of the 2x2 confusion matrix over total
        train_cm = confusion_matrix.value()
        t_accuracy = 100. * (train_cm[0][0] +
                             train_cm[1][1]) / (train_cm.sum())
        vis.plot('train_accuracy', t_accuracy)
        # decay the lr when the epoch loss no longer decreases
        if loss_meter.value()[0] > previous_loss:
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        previous_loss = loss_meter.value()[0]
Example #4
0
def train(**kwargs):
    """Train the HRNet-based demoireing model with gradient accumulation.

    Keyword arguments override fields on the global ``opt`` config.  Resumes
    epoch / optimizer / lr state from ``opt.model_path`` when given,
    validates every epoch, checkpoints every ``opt.save_every`` epochs, and
    decays the learning rate when the loss plateaus (and every 10 epochs).
    """
    #init: copy keyword overrides onto the global config
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    if opt.vis:
        vis = Visualizer(opt.env)
        vis_val = Visualizer('valdemoire')

    #dataset
    # NOTE(review): FiveCrop_transforms and data_transforms are built but
    # never used below — MoireData presumably applies its own transforms;
    # confirm before deleting.
    FiveCrop_transforms = transforms.Compose([
        transforms.FiveCrop(256),
        transforms.Lambda(lambda crops: torch.stack(
            [transforms.ToTensor()(crop) for crop in crops]))
    ])
    data_transforms = transforms.Compose([
        # transforms.RandomCrop(256),
        transforms.ToTensor()
    ])
    train_data = MoireData(opt.train_path)
    test_data = MoireData(opt.test_path, is_val=True)
    train_dataloader = DataLoader(train_data,
                                  batch_size=opt.train_batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers,
                                  drop_last=True)
    test_dataloader = DataLoader(test_data,
                                 batch_size=opt.val_batch_size,
                                 shuffle=True,
                                 num_workers=opt.num_workers,
                                 drop_last=True)

    last_epoch = 0
    #model_init
    cfg.merge_from_file("config/cfg.yaml")
    model = get_pose_net(cfg, pretrained=opt.model_path)  #initweight
    model = model.to(opt.device)

    # baseline validation before any training
    if opt.vis:
        val_loss, val_psnr = val(model, test_dataloader, vis_val)
        print(val_loss, val_psnr)
    else:
        val_loss, val_psnr = val(model, test_dataloader)
        print(val_loss, val_psnr)

    # Charbonnier (smooth-L1-like) loss on pixels, Sobel loss on edges
    criterion_c = L1_Charbonnier_loss()
    criterion_s = L1_Sobel_Loss()
    lr = opt.lr
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=lr,
        weight_decay=0.01  #0.005
    )

    # resume optimizer / epoch / lr state from the checkpoint
    if opt.model_path:
        map_location = lambda storage, loc: storage  # load onto CPU
        checkpoint = torch.load(opt.model_path, map_location=map_location)
        last_epoch = checkpoint["epoch"]
        optimizer_state = checkpoint["optimizer"]
        optimizer.load_state_dict(optimizer_state)

        lr = checkpoint["lr"]
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    loss_meter = meter.AverageValueMeter()
    psnr_meter = meter.AverageValueMeter()
    previous_loss = 1e100
    accumulation_steps = opt.accumulation_steps

    for epoch in range(opt.max_epoch):
        # skip epochs already covered by the resumed checkpoint
        if epoch < last_epoch:
            continue
        loss_meter.reset()
        psnr_meter.reset()
        torch.cuda.empty_cache()
        loss_list = []

        for ii, (moires, clear_list) in tqdm(enumerate(train_dataloader)):
            moires = moires.to(opt.device)
            clears = clear_list[0].to(opt.device)

            output_list, edge_output_list = model(moires)
            outputs, edge_X = output_list[0], edge_output_list[0]

            # loss-mix schedule: keep the configured opt.loss_alpha for the
            # first 20 epochs, then 0.9, then 1.0 (pure Charbonnier)
            if epoch < 20:
                pass
            elif epoch >= 20 and epoch < 40:
                opt.loss_alpha = 0.9
            else:
                opt.loss_alpha = 1.0

            c_loss = criterion_c(outputs, clears)
            s_loss = criterion_s(edge_X, clears)
            loss = opt.loss_alpha * c_loss + (1 - opt.loss_alpha) * s_loss

            # gradient accumulation: scale the loss down and only step the
            # optimizer every `accumulation_steps` batches
            loss = loss / accumulation_steps
            loss.backward()

            if (ii + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            # undo the accumulation scaling when recording the loss
            loss_meter.add(loss.item() * accumulation_steps)

            # convert tensors to images for PSNR and visualisation
            moires = tensor2im(moires)
            outputs = tensor2im(outputs)
            clears = tensor2im(clears)

            psnr = colour.utilities.metric_psnr(outputs, clears)
            psnr_meter.add(psnr)

            if opt.vis and (ii + 1) % opt.plot_every == 0:  # plot once every plot_every batches
                vis.images(moires, win='moire_image')
                vis.images(outputs, win='output_image')
                vis.text(
                    "current outputs_size:{outputs_size},<br/> outputs:{outputs}<br/>"
                    .format(outputs_size=outputs.shape, outputs=outputs),
                    win="size")
                vis.images(clears, win='clear_image')
                #record the train loss to txt
                vis.plot('train_loss',
                         loss_meter.value()
                         [0])  #meter.value() returns (mean, std)
                vis.log(
                    "epoch:{epoch}, lr:{lr}, train_loss:{loss}, train_psnr:{train_psnr}"
                    .format(epoch=epoch + 1,
                            loss=loss_meter.value()[0],
                            lr=lr,
                            train_psnr=psnr_meter.value()[0]))
                loss_list.append(str(loss_meter.value()[0]))

            torch.cuda.empty_cache()
        # end-of-epoch validation
        if opt.vis:
            val_loss, val_psnr = val(model, test_dataloader, vis_val)
            vis.plot('val_loss', val_loss)
            vis.log(
                "epoch:{epoch}, average val_loss:{val_loss}, average val_psnr:{val_psnr}"
                .format(epoch=epoch + 1, val_loss=val_loss, val_psnr=val_psnr))
        else:
            val_loss, val_psnr = val(model, test_dataloader)

        # append this epoch's recorded losses to a text file
        with open(opt.save_prefix + "loss_list.txt", 'a') as f:
            f.write("\nepoch_{}\n".format(epoch + 1))
            f.write('\n'.join(loss_list))

        if (epoch + 1) % opt.save_every == 0 or epoch == 0:  # checkpoint every save_every epochs
            prefix = opt.save_prefix + 'HRnet_epoch{}_'.format(epoch + 1)
            file_name = time.strftime(prefix + '%m%d_%H_%M_%S.pth')
            checkpoint = {
                'epoch': epoch + 1,
                "optimizer": optimizer.state_dict(),
                "model": model.state_dict(),
                "lr": lr
            }
            torch.save(checkpoint, file_name)

        # decay lr when loss stops improving, and unconditionally every 10 epochs
        if (loss_meter.value()[0] > previous_loss) or ((epoch + 1) % 10) == 0:
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        previous_loss = loss_meter.value()[0]

    # final checkpoint after the last epoch
    prefix = opt.save_prefix + 'HRnet_final_'
    file_name = time.strftime(prefix + '%m%d_%H_%M_%S.pth')
    checkpoint = {
        'epoch': epoch + 1,
        "optimizer": optimizer.state_dict(),
        "model": model.state_dict(),
        "lr": lr
    }
    torch.save(checkpoint, file_name)
Example #5
0
def train(**kwargs):
    """Train a DCGAN (0.4+ tensor API): alternate D and G updates.

    Keyword arguments override fields of the global ``opt`` config via
    ``opt._parse``.  Checkpoints and sample images are saved every
    ``opt.save_every`` epochs.
    """
    opt._parse(kwargs)
    if opt.vis:
        from utils.visualize import Visualizer
        vis = Visualizer(opt.env)

    # data: resize + centre-crop to image_size, normalise to [-1, 1]
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.num_workers,
        drop_last=True  # drop the final batch if it is smaller than batch_size
    )

    # networks
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage  # load checkpoints onto CPU
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(opt.device)
    netg.to(opt.device)

    # optimizers and loss
    optimizer_g = t.optim.Adam(netg.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(),
                               opt.lr2,
                               betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(opt.device)

    # real images are labelled 1, fake images 0;
    # `noises` is the generator input, `fix_noises` a fixed batch for visualisation
    true_labels = t.ones(opt.batch_size).to(opt.device)  # real
    fake_labels = t.zeros(opt.batch_size).to(opt.device)  # fake
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(opt.device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(opt.device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(opt.device)

            if ii % opt.d_every == 0:
                # train the discriminator
                optimizer_d.zero_grad()
                ## push it to classify real images as real
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                ## and fake images as fake
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()  # generate fakes from noise
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # train the generator: make the discriminator call fakes real
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                ## visualisation
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 +
                           0.5,
                           win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch + 1) % opt.save_every == 0:
            # save models and sample images
            # NOTE(review): `fix_fake_imgs` is only assigned in the `opt.vis`
            # branch above — NameError here when vis is off; confirm intended.
            tv.utils.save_image(fix_fake_imgs.data[:64],
                                '%s/%s.png' % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            # name checkpoints after the first class folder under ./data
            tag = [
                i for i in os.listdir('./data') if os.path.isdir('./data/' + i)
            ][0]
            t.save(netd.state_dict(), 'checkpoints/%s_d_%s.pth' % (tag, epoch))
            t.save(netg.state_dict(), 'checkpoints/%s_g_%s.pth' % (tag, epoch))
            errord_meter.reset()
            errorg_meter.reset()
Example #6
0
def train(opt):
    """Train a GAN whose G/D classes are looked up by name on ``model``.

    ``opt`` supplies all hyper-parameters.  Uses the legacy ``Variable``
    wrapper for the noise tensors; checkpoints every 20 epochs.
    """
    model_G = getattr(model, opt.G_model)(opt)
    model_D = getattr(model, opt.D_model)(opt)

    vis = Visualizer(opt.env)

    # NOTE(review): checkpoint loading is not implemented yet
    if opt.load_model_path:
        pass

    train_dataloder = dataloader(opt)

    criterion = torch.nn.BCELoss()
    lr_g = opt.lr_g
    lr_d = opt.lr_d
    optimizer_g = torch.optim.Adam(model_G.parameters(),
                                   lr_g,
                                   betas=(opt.beta1, 0.999))
    optimizer_d = torch.optim.Adam(model_D.parameters(),
                                   lr_d,
                                   betas=(opt.beta1, 0.999))

    # label: real = 1, fake = 0; fix_noises is a fixed batch for visualisation
    true_labels = torch.ones(opt.batch_size)
    fake_labels = torch.zeros(opt.batch_size)
    fix_noises = Variable(torch.randn((opt.batch_size, opt.nz, 1, 1)))
    noises = Variable(torch.randn((opt.batch_size, opt.nz, 1, 1)))

    # meter
    loss_G_meter = meter.AverageValueMeter()
    loss_D_meter = meter.AverageValueMeter()

    if opt.use_gpu:
        model_G.cuda()
        model_D.cuda()
        criterion.cuda()
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises, noises = fix_noises.cuda(), noises.cuda()

    for epoch in range(opt.max_epoch):
        loss_G_meter.reset()
        loss_D_meter.reset()

        for ii, (real_img, _) in tqdm(enumerate(train_dataloder)):
            if opt.use_gpu:
                real_img = real_img.cuda()

            if ii % opt.d_every == 0:
                # train distinguish (discriminator) network
                optimizer_d.zero_grad()
                # train by real image
                output = model_D(real_img)
                loss_d_real = criterion(output, true_labels)
                loss_d_real.backward()

                # train by fake image
                # refresh the value of noises
                noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = model_G(noises).detach()
                output = model_D(fake_img)
                loss_d_fake = criterion(output, fake_labels)
                loss_d_fake.backward()

                optimizer_d.step()
                loss_d = loss_d_fake + loss_d_real
                loss_D_meter.add(loss_d.item())

            if ii % opt.g_every == 0:
                # train generate (generator) network
                optimizer_g.zero_grad()
                # train by fake image
                # refresh the value of noises
                noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = model_G(noises)
                output = model_D(fake_img)

                loss_g = criterion(output, true_labels)
                loss_g.backward()

                optimizer_g.step()
                loss_G_meter.add(loss_g.item())

            # NOTE(review): this plots on every iteration whose index is NOT
            # a multiple of print_freq — almost certainly meant
            # `ii % opt.print_freq == 0`; confirm before changing.
            if ii % opt.print_freq:
                vis.plot('loss_d', loss_D_meter.value()[0])
                vis.plot('loss_g', loss_G_meter.value()[0])
                fix_fake_img = model_G(fix_noises)
                vis.images(fix_fake_img.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='fixfake')

        # checkpoint every 20 epochs
        if (epoch + 1) % 20 == 0:
            model_G.save(opt.save_model_path + opt.G_model + '_' + str(epoch))
            model_D.save(opt.save_model_path + opt.D_model + '_' + str(epoch))
Example #7
0
        retrievaled_name = retrievaled_name.split('/')[1]
        retrievaled_name = retrievaled_name.split('.')[0]

        if retrievaled_class == query_class:
            print(ii, 'correct class', query_name, retrievaled_name)
            if query_img.find(retrievaled_name) != -1:
                print(ii, 'correct item', query_name)
                count += 1

    if ii == 0:
        result = query_image
    else:
        result = np.append(result, query_image, axis=0)

# Show the stacked query images from the retrieval loop above in visdom.
result = transform(result)
vis.images(result.numpy(), win='result')

# number of correctly retrieved items found in the pass above
print(count)

# reset the counters for the following top-K (K=5) retrieval pass
count = 0
count_5 = 0
K = 5

div = 0

for ii, (query_sketch,
         query_name) in tqdm.tqdm(enumerate(zip(sketch_feature, sketch_name))):
    query_sketch = np.reshape(query_sketch, [1, np.shape(query_sketch)[0]])

    query_split = query_name.split('/')
    query_class = query_split[0]