Example #1
def val(val_dataloader, network, save=False):
    # network.eval()
    dice_meter_b = AverageValueMeter()
    dice_meter_f = AverageValueMeter()

    dice_meter_b.reset()
    dice_meter_f.reset()

    images = []
    with torch.no_grad():
        for i, (image, mask, _, _) in enumerate(val_dataloader):
            if mask.sum() == 0:
                continue
            image, mask = image.to(device), mask.to(device)

            proba = F.softmax(network(image), dim=1)
            predicted_mask = proba.max(1)[1]
            iou = dice_loss(predicted_mask, mask)

            dice_meter_f.add(iou[1])
            dice_meter_b.add(iou[0])

            if save:
                images = save_images(images, image, mask, proba[:, 1],
                                     predicted_mask)
    if save:
        grid = make_grid(images, nrow=4)
        return [[dice_meter_b.value()[0], dice_meter_f.value()[0]], grid]
    else:
        return [[dice_meter_b.value()[0], dice_meter_f.value()[0]], None]
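Every example on this page leans on torchnet's AverageValueMeter; the stand-in below is only a sketch of the interface these snippets use (reset(), add(value, n=1), value() -> (mean, std)), not the torchnet source.

class AverageValueMeterSketch:
    """Illustrative stand-in for torchnet.meter.AverageValueMeter."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.n = 0

    def add(self, value, n=1):
        # torchnet convention: pass a pre-scaled value when n > 1,
        # e.g. meter.add(loss * batch_size, batch_size), so that mean = sum / n
        self.sum += value
        self.n += n

    def value(self):
        # torchnet returns (mean, std); this sketch only tracks the mean,
        # which is all the examples on this page read via value()[0]
        mean = self.sum / self.n if self.n else float('nan')
        return mean, None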
Example #2
def train():
    vis = Visualizer(server='http://turing.livia.etsmtl.ca', env='EEG')
    data_root = '/home/AN96120/python_project/Seizure Prediction/processed_data/fft_meanlog_std_lowcut0.1highcut180nfreq_bands12win_length_sec60stride_sec60/Dog_1'
    dataloader_train = get_dataloader(data_root, training=True)
    dataloader_test = get_dataloader(data_root, training=False)
    # No interaction has been found between the training and the testing datasets.
    weights = t.Tensor([
        1 / (np.array(dataloader_train.dataset.targets) == 0).mean(),
        1 / (np.array(dataloader_train.dataset.targets) == 1).mean()
    ])
    criterion = nn.CrossEntropyLoss(weight=weights.cuda())

    net = convNet()
    net.cuda()

    optimiser = t.optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-4)
    loss_avg = AverageValueMeter()
    epochs = 10000
    for epoch in range(epochs):
        loss_avg.reset()
        for ii, (data, targets) in enumerate(dataloader_train):
            data, targets= data.type(t.FloatTensor), targets.type(t.LongTensor)
            data = data.cuda()
            targets = targets.cuda()
            optimiser.zero_grad()
            output = net(data)
            loss = criterion(output,targets)
            loss_avg.add(loss.item())
            loss.backward()
            optimiser.step()
        vis.plot('loss',loss_avg.value()[0])

        _, auc_train = val(dataloader_train, net)
        _, auc_test = val(dataloader_test, net)
        print(auc_train, auc_test)
Example #3
def val(net, dataloader_):
    global highest_iou
    net.eval()
    iou_meter_val = AverageValueMeter()
    loss_meter_val = AverageValueMeter()
    iou_meter_val.reset()
    for i, (img, mask, _) in tqdm(enumerate(dataloader_)):
        (img, mask) = (img.cuda(), mask.cuda()) if (torch.cuda.is_available()
                                                    and use_cuda) else (img,
                                                                        mask)
        pred_val = net(img)
        loss_val = criterion(pred_val, mask.squeeze(1))
        loss_meter_val.add(loss_val.item())
        iou_val = iou_loss(pred2segmentation(pred_val),
                           mask.squeeze(1).float(), class_number)[1]
        iou_meter_val.add(iou_val)
        if i % val_print_frequncy == 0:
            showImages(board_val_image, img, mask, pred2segmentation(pred_val))

    board_loss.plot('val_iou_per_epoch', iou_meter_val.value()[0])
    board_loss.plot('val_loss_per_epoch', loss_meter_val.value()[0])
    net.train()
    if highest_iou < iou_meter_val.value()[0]:
        highest_iou = iou_meter_val.value()[0]
        torch.save(
            net.state_dict(), 'checkpoint/modified_ENet_%.3f_%s.pth' %
            (iou_meter_val.value()[0], 'equal_' + str(Equalize)))
        print('The highest IOU is:%.3f' % iou_meter_val.value()[0],
              'Model saved.')
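Several examples on this page (including this one) call a pred2segmentation helper that is not shown here. Assuming the network emits per-class scores of shape (N, C, H, W), a plausible one-liner would be:

def pred2segmentation(prediction):
    # hypothetical helper: arg-max over the class channel yields the hard label map
    return prediction.max(dim=1)[1]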
Example #4
def val():
    global highest_dice_loss
    dice_loss_meter = AverageValueMeter()
    dice_loss_meter.reset()
    for i, (img, mask, weak_mask, _) in enumerate(val_loader):
        if (weak_mask.sum() <= 3) or (mask.sum() <= 10):
            # print('No mask has been found')
            continue
        if not ((list(img.shape[-2:]) == list(mask.shape[-2:])) and (
                list(img.shape[-2:]) == list(weak_mask.shape[-2:]))):
            continue
        img, mask, weak_mask = img.cuda(), mask.cuda(), weak_mask.cuda()

        predict_ = F.softmax(net(img), dim=1)
        segm = pred2segmentation(predict_)
        diceloss_F = dice_loss(segm, mask)
        diceloss_B = dice_loss(1 - segm, 1 - mask)
        dice_loss_meter.add((diceloss_F + diceloss_B).item() / 2)

        if i % 100 == 0:
            board_val_image.image(img[0], 'medical image')
            board_val_image.image(color_transform(weak_mask[0]), 'weak_mask')
            board_val_image.image(color_transform(segm[0]), 'prediction')
    board_loss.plot('dice_loss for validationset', dice_loss_meter.value()[0])

    if dice_loss_meter.value()[0] > highest_dice_loss:
        highest_dice_loss = dice_loss_meter.value()[0]
        torch.save(net.state_dict(), 'Enet_Square_barrier.pth')
        print('saved with dice:%f' % highest_dice_loss)
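The dice_loss helper used here (and in a few of the other examples) is also defined elsewhere. As an assumption about what it computes, a minimal foreground Dice coefficient for binary masks might look like:

def dice_coefficient(pred_mask, target_mask, eps=1e-6):
    # hypothetical stand-in: 2 * |A ∩ B| / (|A| + |B|) over flattened binary masks
    pred = pred_mask.float().view(-1)
    target = target_mask.float().view(-1)
    intersection = (pred * target).sum()
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)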
Example #5
def pretrain(dataloader, network, path=None):
    class config:
        lr = 1e-3
        epochs = 100
        path = '../checkpoint/pretrained_net.pth'

    pretrain_config = config()
    if path:
        pretrain_config.path = path
    network.to(device)
    criterion_ = CrossEntropyLoss2d()
    optimiser_ = torch.optim.Adam(network.parameters(), pretrain_config.lr)
    loss_meter = AverageValueMeter()
    for epoch in range(pretrain_config.epochs):
        loss_meter.reset()

        for i, (img, mask, weak_mask, _) in tqdm(enumerate(dataloader)):
            img, mask = img.to(device), mask.to(device)
            optimiser_.zero_grad()
            output = network(img)
            loss = criterion_(output, mask.squeeze(1))
            loss.backward()
            optimiser_.step()
            loss_meter.add(loss.item())

        # import ipdb
        # ipdb.set_trace()
        print(loss_meter.value()[0])
        torch.save(network.state_dict(), pretrain_config.path)
        # torch.save(network.parameters(),path)
        print('pretrained model saved.')
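CrossEntropyLoss2d is another helper these pretraining examples assume. In many segmentation codebases it is just a thin wrapper around NLLLoss applied to log-softmax scores; a sketch under that assumption:

import torch.nn as nn
import torch.nn.functional as F

class CrossEntropyLoss2d(nn.Module):
    # common pattern, shown here as an assumption about the helper used above
    def __init__(self, weight=None):
        super().__init__()
        self.nll_loss = nn.NLLLoss(weight)

    def forward(self, outputs, targets):
        # outputs: (N, C, H, W) raw scores, targets: (N, H, W) class indices
        return self.nll_loss(F.log_softmax(outputs, dim=1), targets)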
Example #6
def val(dataloader, net):
    avg_acc = AverageValueMeter()
    avg_acc.reset()
    y_true = []
    y_predict = []
    y_predict_proba = []
    net.eval()
    with t.no_grad():
        for i, (data, target) in enumerate(dataloader):
            data = data.type(t.FloatTensor)
            data = data.cuda()
            target = target.cuda()
            output = net(data)
            decision = output.max(1)[1]
            y_predict.extend(decision.cpu().numpy().tolist())
            proba = F.softmax(output,dim=1)[:,1]
            y_predict_proba.extend(proba.cpu().numpy().tolist())
            y_true.extend(target.cpu().numpy().tolist())
            acc = (decision == target).sum().item() / float(len(target))  # np.float was removed in recent NumPy
            avg_acc.add(acc)
    avg_auc = roc_auc_score(y_true,y_predict_proba)

    cnf_matrix = confusion_matrix(y_true, y_predict)
    np.set_printoptions(precision=2)
    # print(avg_auc)
    net.train()
    return avg_acc.value()[0],avg_auc
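The metrics in this example come from scikit-learn; the import it appears to rely on is:

from sklearn.metrics import roc_auc_score, confusion_matrix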
Example #7
class Trainer:
    def __init__(self, args, model):
        self.name = args.name
        self.model = model
        self.l1win = None
        self.l2win = None
        self.l1meter = AverageValueMeter()
        self.l2meter = AverageValueMeter()
        self.visdom = Visdom(
            port=args.vis_port) if args.vis_steps > 0 else None

    @property
    def mode(self):
        return 'training' if self.model.training else 'testing'

    @property
    def losses(self):
        # value() returns (mean, std); report the mean of each meter
        return self.l1meter.value()[0], self.l2meter.value()[0]

    def reset(self):
        self.l1meter.reset()
        self.l2meter.reset()

    def log_losses(self, epoch, step):
        l1, l2 = self.losses
        message = f'{self.name} is {self.mode} (epoch: {epoch}, step: {step}) '
        message += f'l1 average: {l1}, l2 average: {l2}'
        print(message)

    def vis_losses(self, epoch):
        l1, l2 = self.losses
        x, y1, y2 = np.array([epoch]), np.array([l1]), np.array([l2])

        if self.l1win is None or self.l2win is None:
            opt = dict(xlabel='epochs',
                       xtickstep=1,
                       ylabel='mean loss',
                       width=900)
            self.l1win = self.visdom.line(X=x,
                                          Y=y1,
                                          opts=dict(
                                              title=f'l1 loss ({self.name})',
                                              **opt))
            self.l2win = self.visdom.line(X=x,
                                          Y=y2,
                                          opts=dict(
                                              title=f'l2 loss ({self.name})',
                                              **opt))
        else:
            n = '1' if self.model.training else '2'
            self.visdom.updateTrace(X=x, Y=y1, win=self.l1win, name=n)
            self.visdom.updateTrace(X=x, Y=y2, win=self.l2win, name=n)

    def vis_images(self, epoch, step, images):
        title = f'({self.name}, epoch: {epoch}, step: {step})'
        for key, image in images.items():
            self.visdom.image(image.cpu().data,
                              env=self.mode,
                              opts=dict(title=f'{key} {title}'))
Example #8
def val(val_dataloader, network):
    network.eval()
    dice_meter = AverageValueMeter()
    dice_meter.reset()
    for i, (image, mask, _, _) in enumerate(val_dataloader):
        image, mask = image.to(device), mask.to(device)
        proba = F.softmax(network(image), dim=1)
        predicted_mask = proba.max(1)[1]
        iou = dice_loss(predicted_mask, mask).item()
        dice_meter.add(iou)
    print('val iou:  %.6f' % dice_meter.value()[0])
    return dice_meter.value()[0]
Example #9
def evaluate(net, dataloader, device):
    net.eval()
    dice_meter = AverageValueMeter()
    dice_meter.reset()
    with torch.no_grad():
        for i, (img, mask, path) in enumerate(dataloader):
            img, mask = img.to(device), mask.to(device)
            pred = net(img)
            pred_mask = pred2segmentation(pred)
            dice_meter.add(dice_loss(pred_mask, mask))

    net.train()
    return dice_meter.value()[0]
Example #10
class CalculateLossCallback(TrainingCallback):
    def __init__(self, key):
        self.key = key
        self.average_value_meter = AverageValueMeter()

    def on_mode_begin(self, mode, log):
        self.average_value_meter.reset()
        log[self.key] = float('NaN')

    def on_batch_end(self, batch, log):
        batch_size = log['batch_size']
        self.average_value_meter.add(log['loss'] * batch_size, batch_size)
        log[self.key] = self.average_value_meter.value()[0]
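Because add() receives the batch-scaled loss together with the batch size, value()[0] is the per-sample running average even when batches differ in size. A small illustration with made-up numbers:

meter = AverageValueMeter()
meter.add(0.5 * 32, 32)   # batch of 32 samples with mean loss 0.5
meter.add(0.7 * 8, 8)     # smaller final batch with mean loss 0.7
print(meter.value()[0])   # (16.0 + 5.6) / 40 = 0.54, the sample-weighted mean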
Example #11
def train(mode='CL'):
    model = dccrn(mode)
    model.to(opt.device)

    train_data = THCHS30(phase='train')
    train_loader = DataLoader(train_data,
                              batch_size=opt.batch_size,
                              num_workers=opt.num_workers,
                              shuffle=True)

    optimizer = Adam(model.parameters(), lr=opt.lr)
    scheduler = MultiStepLR(optimizer,
                            milestones=[
                                int(opt.max_epoch * 0.5),
                                int(opt.max_epoch * 0.7),
                                int(opt.max_epoch * 0.9)
                            ],
                            gamma=opt.lr_decay)
    criterion = SISNRLoss()

    loss_meter = AverageValueMeter()

    for epoch in range(0, opt.max_epoch):
        loss_meter.reset()
        for i, (data, label) in enumerate(train_loader):
            data = data.to(opt.device)
            label = label.to(opt.device)

            spec, wav = model(data)

            optimizer.zero_grad()
            loss = criterion(wav, label)
            loss.backward()
            optimizer.step()

            loss_meter.add(loss.item())

            if (i + 1) % opt.verbose_inter == 0:
                print('epoch', epoch + 1, 'batch', i + 1, 'SI-SNR',
                      -loss_meter.value()[0])
        if (epoch + 1) % opt.save_inter == 0:
            print('save model at epoch {0} ...'.format(epoch + 1))
            save_path = os.path.join(
                opt.checkpoint_root,
                'DCCRN_{0}_{1}.pth'.format(mode, epoch + 1))
            torch.save(model.state_dict(), save_path)

        scheduler.step()

    save_path = os.path.join(opt.checkpoint_root, 'DCCRN_{0}.pth'.format(mode))
    torch.save(model.state_dict(), save_path)
Example #12
def val(dataloader, net):
    net.eval()
    acc = AverageValueMeter()
    acc.reset()
    for i, (img, label) in enumerate(dataloader):
        batch_size = len(label)
        # torch.autograd.Variable is deprecated; tensors can be used directly
        images = img.cuda()
        labels = label.squeeze().cuda()
        output = net(images)
        predictedLabel = torch.max(output, 1)[1]
        acc_ = (predictedLabel == labels).sum().type(torch.FloatTensor) / batch_size
        acc.add(acc_.item())
    net.train()
    return acc.value()[0]
Example #13
def pretrain(train_dataloader, val_dataloader_, network, path=None, split_ratio=0.1):
    highest_iou = -1
    class config:
        lr = 1e-3
        epochs = 100
        path = 'checkpoint'


    pretrain_config = config()
    if path:
        pretrain_config.path = path
    network.to(device)
    criterion_ = CrossEntropyLoss2d()
    optimiser_ = torch.optim.Adam(network.parameters(),pretrain_config.lr)
    loss_meter = AverageValueMeter()
    fiou_tables = []

    for iteration in range(pretrain_config.epochs):
        loss_meter.reset()

        for i, (img, mask, weak_mask, _) in tqdm(enumerate(train_dataloader)):
            img, mask = img.to(device), mask.to(device)
            optimiser_.zero_grad()
            output = network(img)
            loss = criterion_(output, mask.squeeze(1))
            loss.backward()
            optimiser_.step()
            loss_meter.add(loss.item())
        print('train_loss: %.6f' % loss_meter.value()[0])

        if (iteration + 1) % 50 == 0:
            for param_group in optimiser_.param_groups:
                param_group['lr'] = param_group['lr'] * 0.5
                print('learning rate:', param_group['lr'])

        val_iou = val(val_dataloader_,network)
        fiou_tables.append(val_iou)
        if val_iou > highest_iou:
            highest_iou = val_iou
            torch.save(network.state_dict(),
                       os.path.join(pretrain_config.path, 'model_%.4f_split_%.3f.pth' % (val_iou, split_ratio)))
            print('pretrained model saved with %.4f.'%highest_iou)
    return fiou_tables
Example #14
def val(model, dataloader, criterion):

    model.eval()
    device = t.device('cuda') if opt.use_gpu else t.device('cpu')
    ncorrect = 0
    nsample = 0
    loss_meter = AverageValueMeter()
    loss_meter.reset()
    for ii, (data, label) in enumerate(dataloader):
        nsample += data.size()[0]
        feature = data.to(device)
        target = label.to(device)
        prob = model(feature)
        loss = criterion(prob, target)
        score = t.nn.functional.softmax(prob, dim=1)
        index = score.topk(1)[1].view(-1)
        loss_meter.add(loss.item())
        ncorrect += (index == target).cpu().sum().item()

    accu = float(ncorrect) / nsample * 100
    loss = loss_meter.value()[0]
    return accu, loss
Example #15
def train():
    totalloss_meter = AverageValueMeter()
    sizeloss_meter = AverageValueMeter()
    celoss_meter = AverageValueMeter()

    for epoch in range(max_epoch):
        totalloss_meter.reset()
        celoss_meter.reset()
        sizeloss_meter.reset()
        if epoch % 5 == 0:
            for param_group in optimiser.param_groups:
                param_group['lr'] = lr * (0.9 ** (epoch // 3))
                print('learning rate:', param_group['lr'])
            print('save model:')
            # torch.save(net.state_dict(), 'U_net_2Class.pth')

        for i, (img, mask, weak_mask, _) in tqdm(enumerate(train_loader)):
            if (weak_mask.sum() == 0) or (mask.sum() == 0):
                # print('No mask has been found')
                continue
            if not ((list(img.shape[-2:]) == list(mask.shape[-2:])) and (
                    list(img.shape[-2:]) == list(weak_mask.shape[-2:]))):
                continue
            img, mask, weak_mask = img.cuda(), mask.cuda(), weak_mask.cuda()
            optimiser.zero_grad()
            predict = net(img)
            loss_ce = partialCECriterion(predict, weak_mask.squeeze(1))
            # loss_ce = torch.Tensor([0]).cuda()
            # celoss_meter.add(loss_ce.item())
            loss_size = sizeCriterion(predict)
            # loss_size = torch.Tensor([0]).cuda()
            sizeloss_meter.add(loss_size.item())
            loss = loss_ce + loss_size
            totalloss_meter.add(loss.item())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1e-4)
            optimiser.step()
            if i % 50 == 0:
                predict_ = F.softmax(predict, dim=1)
                segm = pred2segmentation(predict)
                print("ce_loss:%.4f,  size_loss:%.4f, FB percentage:%.2f" % (loss_ce.item(), loss_size.item(), ((
                                                                                                                            predict_[
                                                                                                                            :,
                                                                                                                            1,
                                                                                                                            :,
                                                                                                                            :] * weak_mask.data.float()).sum() / weak_mask.data.float().sum()).item()))
                board_train_image.image(img[0], 'medical image')
                board_train_image.image(color_transform(mask[0]), 'full_mask')
                board_train_image.image(color_transform(weak_mask[0]), 'weak_mask')
                board_train_image.image(color_transform(segm[0]), 'prediction')
                if totalloss_meter.value()[0] < 1:
                    board_loss.plot('ce_loss', -np.log(loss_ce.item() + 1e-6))
                    board_loss.plot('size_loss', -np.log(loss_size.item() + 1e-6))
                    # board_loss.plot('size_loss', -np.log(sizeloss_meter.value()[0]))
        # print('train loss:%.5f'%celoss_meter.value()[0])
        val()
Example #16
def train():
    net.train()
    iou_meter = AverageValueMeter()
    loss_meter = AverageValueMeter()
    for epoch in range(max_epoch):
        iou_meter.reset()
        loss_meter.reset()
        if epoch % 5 == 0:
            for param_group in optimiser.param_groups:
                param_group['lr'] = param_group['lr'] * (0.95**(epoch // 10))
                print('learning rate:', param_group['lr'])

        for i, (img, mask, _) in tqdm(enumerate(train_loader)):
            (img, mask) = (img.cuda(),
                           mask.cuda()) if (torch.cuda.is_available()
                                            and use_cuda) else (img, mask)
            optimiser.zero_grad()
            pred = net(img)
            loss = criterion(pred, mask.squeeze(1))
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(net.parameters(), 1e-3)
            optimiser.step()
            loss_meter.add(loss.item())
            iou = iou_loss(pred2segmentation(pred),
                           mask.squeeze(1).float(), class_number)[1]
            iou_meter.add(iou)

            if i % train_print_frequncy == 0:
                showImages(board_train_image, img, mask,
                           pred2segmentation(pred))

        board_loss.plot('train_iou_per_epoch', iou_meter.value()[0])
        board_loss.plot('train_loss_per_epoch', loss_meter.value()[0])

        val(net, val_loader)
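iou_loss(prediction, target, class_number) is assumed here to return a per-class IoU sequence, with index [1] being the foreground class. A rough sketch under that assumption:

def iou_per_class(pred_segmentation, target, class_number, eps=1e-6):
    # hypothetical helper: intersection-over-union for each class index
    ious = []
    for c in range(class_number):
        pred_c = (pred_segmentation == c).float()
        target_c = (target == c).float()
        intersection = (pred_c * target_c).sum()
        union = pred_c.sum() + target_c.sum() - intersection
        ious.append(((intersection + eps) / (union + eps)).item())
    return ious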
Example #17
    cudnn.benchmark = True

valdata = ISICdata(root=root,
                   model='train',
                   transform=True,
                   dataAugment=False,
                   equalize=Equalize)
val_loader = DataLoader(valdata,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=number_workers,
                        pin_memory=True)

iou_meter_val = AverageValueMeter()
iou_crf_meter_val = AverageValueMeter()
iou_meter_val.reset()
iou_crf_meter_val.reset()


def graphcut_as_postprocessing(heatmap, image):

    fgmarkers = (heatmap > 0.99).astype(float)  # np.float was removed in NumPy 1.24
    fgmarkers_ = np.zeros(shape=(*fgmarkers.shape, 3))
    for i in range(3):
        fgmarkers_[:, :, i] = fgmarkers

    bgmarkers = (heatmap < 0.01).astype(float)
    bgmarkers_ = np.zeros(shape=(*bgmarkers.shape, 3))
    for i in range(3):
        bgmarkers_[:, :, i] = bgmarkers
Example #18
val_loader = DataLoader(val_set, batch_size=batch_size_val, num_workers=num_workers, shuffle=True)
num_classes=2
net = UNet(num_classes=num_classes).cuda()
net.load_state_dict(torch.load('U_net_2Class.pth'))
# net.final = nn.Conv2d(64, 4, 1).cuda()
# net = Enet(num_classes=2).cuda()
optimiser = torch.optim.Adam(net.parameters(),lr=lr)
weight = torch.ones(num_classes)
# weight[0]=0
criterion = CrossEntropyLoss2d(weight.cuda()).cuda()

if __name__=="__main__":
    celoss_meter = AverageValueMeter()

    for epoch in range(max_epoch):
        celoss_meter.reset()
        if epoch % 5 == 0:
            for param_group in optimiser.param_groups:
                param_group['lr'] = lr * (0.98 ** (epoch // 3))
                print('learning rate:', param_group['lr'])
            print('save model:')
            torch.save(net.state_dict(), 'U_net_2Class.pth')

        for i, (img, mask, _, _) in tqdm(enumerate(train_loader)):
            img, mask = img.cuda(), mask.cuda()
            optimiser.zero_grad()
            predict = net(img)
            loss = criterion(predict,mask.squeeze(1))
            segm = pred2segmentation(predict)
            loss.backward()
            optimiser.step()
Example #19
    def train(self):

        if self.net == 'vgg16':
            photo_net = DataParallel(self._get_vgg16()).cuda()
            sketch_net = DataParallel(self._get_vgg16()).cuda()
        elif self.net == 'resnet34':
            photo_net = DataParallel(self._get_resnet34()).cuda()
            sketch_net = DataParallel(self._get_resnet34()).cuda()
        elif self.net == 'resnet50':
            photo_net = DataParallel(self._get_resnet50()).cuda()
            sketch_net = DataParallel(self._get_resnet50()).cuda()

        if self.fine_tune:
            photo_net_root = self.model_root
            sketch_net_root = self.model_root.replace('photo', 'sketch')

            photo_net.load_state_dict(
                t.load(photo_net_root, map_location=t.device('cpu')))
            sketch_net.load_state_dict(
                t.load(sketch_net_root, map_location=t.device('cpu')))

        print('net')
        print(photo_net)

        # triplet_loss = nn.TripletMarginLoss(margin=self.margin, p=self.p).cuda()
        photo_cat_loss = nn.CrossEntropyLoss().cuda()
        sketch_cat_loss = nn.CrossEntropyLoss().cuda()

        my_triplet_loss = TripletLoss().cuda()

        # optimizer
        photo_optimizer = t.optim.Adam(photo_net.parameters(), lr=self.lr)
        sketch_optimizer = t.optim.Adam(sketch_net.parameters(), lr=self.lr)

        if self.vis:
            vis = Visualizer(self.env)

        triplet_loss_meter = AverageValueMeter()
        sketch_cat_loss_meter = AverageValueMeter()
        photo_cat_loss_meter = AverageValueMeter()

        data_loader = TripleDataLoader(self.dataloader_opt)
        dataset = data_loader.load_data()

        for epoch in range(self.epochs):

            print('---------------{0}---------------'.format(epoch))

            if self.test and epoch % self.test_f == 0:

                tester_config = Config()
                tester_config.test_bs = 128
                tester_config.photo_net = photo_net
                tester_config.sketch_net = sketch_net

                tester_config.photo_test = self.photo_test
                tester_config.sketch_test = self.sketch_test

                tester = Tester(tester_config)
                test_result = tester.test_instance_recall()

                result_key = list(test_result.keys())
                vis.plot('recall',
                         np.array([
                             test_result[result_key[0]],
                             test_result[result_key[1]]
                         ]),
                         legend=[result_key[0], result_key[1]])
                if self.save_model:
                    t.save(
                        photo_net.state_dict(), self.save_dir + '/photo' +
                        '/photo_' + self.net + '_%s.pth' % epoch)
                    t.save(
                        sketch_net.state_dict(), self.save_dir + '/sketch' +
                        '/sketch_' + self.net + '_%s.pth' % epoch)

            photo_net.train()
            sketch_net.train()

            for ii, data in enumerate(dataset):

                photo_optimizer.zero_grad()
                sketch_optimizer.zero_grad()

                photo = data['P'].cuda()
                sketch = data['S'].cuda()
                label = data['L'].cuda()

                p_cat, p_feature = photo_net(photo)
                s_cat, s_feature = sketch_net(sketch)

                # category loss
                p_cat_loss = photo_cat_loss(p_cat, label)
                s_cat_loss = sketch_cat_loss(s_cat, label)

                photo_cat_loss_meter.add(p_cat_loss.item())
                sketch_cat_loss_meter.add(s_cat_loss.item())

                # triplet loss
                loss = p_cat_loss + s_cat_loss

                # tri_record = 0.
                '''
                for i in range(self.batch_size):
                    # negative
                    negative_feature = t.cat([p_feature[0:i, :], p_feature[i + 1:, :]], dim=0)
                    # print('negative_feature.size :', negative_feature.size())
                    # photo_feature
                    anchor_feature = s_feature[i, :]
                    anchor_feature = anchor_feature.expand_as(negative_feature)
                    # print('anchor_feature.size :', anchor_feature.size())

                    # positive
                    positive_feature = p_feature[i, :]
                    positive_feature = positive_feature.expand_as(negative_feature)
                    # print('positive_feature.size :', positive_feature.size())

                    tri_loss = triplet_loss(anchor_feature, positive_feature, negative_feature)

                    tri_record = tri_record + tri_loss

                    # print('tri_loss :', tri_loss)
                    loss = loss + tri_loss
                '''
                # print('tri_record : ', tri_record)

                my_tri_loss = my_triplet_loss(
                    s_feature, p_feature) / (self.batch_size - 1)
                triplet_loss_meter.add(my_tri_loss.item())
                # print('my_tri_loss : ', my_tri_loss)

                # print(tri_record - my_tri_loss)
                loss = loss + my_tri_loss
                # print('loss :', loss)
                # loss = loss / opt.batch_size

                loss.backward()

                photo_optimizer.step()
                sketch_optimizer.step()

                if self.vis:
                    vis.plot('triplet_loss',
                             np.array([
                                 triplet_loss_meter.value()[0],
                                 photo_cat_loss_meter.value()[0],
                                 sketch_cat_loss_meter.value()[0]
                             ]),
                             legend=[
                                 'triplet_loss', 'photo_cat_loss',
                                 'sketch_cat_loss'
                             ])

                triplet_loss_meter.reset()
                photo_cat_loss_meter.reset()
                sketch_cat_loss_meter.reset()
Example #20
                       normalize=True)
            save_image(img_grid_gl,
                       "images/%d_gl.png" % batches_done,
                       normalize=True)

            # vis.images(imgs_lr.detach().cpu().numpy()[:1] * 0.5 + 0.5, win='imgs_lr_train')
            # vis.images(gen_hr.data.cpu().numpy()[:1] * 0.5 + 0.5, win='img_gen_train')
            # vis.images(imgs_hr.data.cpu().numpy()[:1] * 0.5 + 0.5, win='img_hr_train')
            vis.plot('loss_G_train', loss_G_meter.value()[0])
            vis.plot('loss_D_train', loss_D_meter.value()[0])
            vis.plot('loss_GAN_train', loss_GAN_meter.value()[0])
            vis.plot('loss_content_train', loss_content_meter.value()[0])
            vis.plot('loss_real_train', loss_real_meter.value()[0])
            vis.plot('loss_fake_train', loss_fake_meter.value()[0])

    loss_GAN_meter.reset()
    loss_content_meter.reset()
    loss_G_meter.reset()
    loss_real_meter.reset()
    loss_fake_meter.reset()
    loss_D_meter.reset()

    # validate the generator model
    generator.eval()
    valing_out_path = 'valing_results/SR_factor_' + str(
        opt.scale_factor) + '/' + 'epoch_' + str(epoch) + '/'
    os.makedirs(valing_out_path, exist_ok=True)

    with torch.no_grad():
        # val_bar = tqdm(val_dataloader)
        valing_results = {
Example #21
class AlphaGAN(object):
    def __init__(self, args):
        self.epoch = args.epoch
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.gpu_mode = args.gpu_mode
        self.device = args.device
        self.lrG = args.lrG  # learning rate of the generator
        self.lrD = args.lrD  # learning rate of the discriminator
        self.com_loss = args.com_loss  # compositional-loss flag; when it is off, the discriminator sees the 1-channel alpha instead of the composited image
        self.fine_tune = args.fine_tune
        self.visual = args.visual
        self.env = args.env
        self.d_every = args.d_every
        self.g_every = args.g_every

        if self.fine_tune:
            self.model_G = args.model
            self.model_D = args.model.replace('netG', 'netD')

        # network init
        self.G = NetG()
        if self.com_loss:
            self.D = NLayerDiscriminator(input_nc=4)
        else:
            self.D = NLayerDiscriminator(input_nc=2)

        print(self.G)
        print(self.D)

        if self.fine_tune:
            self.G.load_state_dict(t.load(self.model_G))
            self.D.load_state_dict(t.load(self.model_D))

        self.G_optimizer = t.optim.Adam(self.G.parameters(), lr=self.lrG)
        self.D_optimizer = t.optim.Adam(self.D.parameters(), lr=self.lrD)
        if self.gpu_mode:
            self.G.to(self.device)
            self.D.to(self.device)
            self.G_criterion = t.nn.SmoothL1Loss().to(self.device)
            self.D_criterion = t.nn.MSELoss().to(self.device)

        self.G_error_meter = AverageValueMeter()  # generator loss
        self.Alpha_loss_meter = AverageValueMeter()  # alpha loss
        self.Com_loss_meter = AverageValueMeter()  # compositional loss
        self.Adv_loss_meter = AverageValueMeter()  # adversarial loss
        self.D_error_meter = AverageValueMeter()  # discriminator loss

    def train(self, dataset):
        if self.visual:
            vis = Visualizer(self.env)

        for epoch in range(self.epoch):
            for ii, data in tqdm.tqdm(enumerate(dataset)):
                real_img = data['I']
                tri_img = data['T']  #Trimap

                if self.com_loss:
                    bg_img = data['B'].to(self.device)  #Background image
                    fg_img = data['F'].to(self.device)  #Foreground image

                # input to the G, 4 Channel, Image and Trimap concatenated
                input_img = t.tensor(
                    np.append(real_img.numpy(), tri_img.numpy(),
                              axis=1)).to(self.device)

                # real_alpha
                real_alpha = data['A'].to(self.device)

                # vis.images(real_img.numpy()*0.5 + 0.5, win='input_real_img')
                # vis.images(real_alpha.cpu().numpy()*0.5 + 0.5, win='real_alpha')
                # vis.images(tri_img.numpy()*0.5 + 0.5, win='tri_map')

                # train D
                if ii % self.d_every == 0:
                    self.D_optimizer.zero_grad()

                    # real_img_d = input_img[:, 0:3, :, :]
                    tri_img_d = input_img[:, 3:4, :, :]

                    #alpha
                    if self.com_loss:
                        real_d = self.D(input_img)
                    else:
                        real_d = self.D(t.cat([real_alpha, tri_img_d], dim=1))

                    target_real_label = t.tensor(1.0)  #1 for real
                    #The shape of real_d would be NxN
                    target_real = target_real_label.expand_as(real_d).to(
                        self.device)

                    loss_d_real = self.D_criterion(real_d, target_real)

                    #fake_alpha, is the predicted alpha by the generator
                    fake_alpha = self.G(input_img)
                    if self.com_loss:
                        #Constructing the fake Image
                        fake_img = fake_alpha * fg_img + (1 -
                                                          fake_alpha) * bg_img
                        fake_d = self.D(t.cat([fake_img, tri_img_d], dim=1))
                    else:
                        fake_d = self.D(t.cat([fake_alpha, tri_img_d], dim=1))
                    target_fake_label = t.tensor(0.0)

                    target_fake = target_fake_label.expand_as(fake_d).to(
                        self.device)

                    loss_d_fake = self.D_criterion(fake_d, target_fake)

                    loss_D = loss_d_real + loss_d_fake
                    #Backpropagation of the  discriminator loss
                    loss_D.backward()
                    self.D_optimizer.step()
                    self.D_error_meter.add(loss_D.item())

                # train G
                if ii % self.g_every == 0:
                    #Initialize the Optimizer
                    self.G_optimizer.zero_grad()

                    real_img_g = input_img[:, 0:3, :, :]
                    tri_img_g = input_img[:, 3:4, :, :]

                    fake_alpha = self.G(input_img)
                    # fake_alpha  is the output of the Generator
                    loss_g_alpha = self.G_criterion(fake_alpha, real_alpha)
                    #alpha_loss, difference between predicted alpha and the real alpha
                    loss_G = loss_g_alpha
                    self.Alpha_loss_meter.add(loss_g_alpha.item())

                    if self.com_loss:
                        fake_img = fake_alpha * fg_img + (1 -
                                                          fake_alpha) * bg_img
                        loss_g_cmp = self.G_criterion(
                            fake_img, real_img_g)  #Composition Loss

                        fake_d = self.D(t.cat([fake_img, tri_img_g], dim=1))
                        self.Com_loss_meter.add(loss_g_cmp.item())
                        loss_G = loss_G + loss_g_cmp

                    else:
                        fake_d = self.D(t.cat([fake_alpha, tri_img_g], dim=1))
                    target_fake = t.tensor(1.0).expand_as(fake_d).to(
                        self.device)
                    # the generator's target is to make the discriminator output 1
                    loss_g_d = self.D_criterion(fake_d, target_fake)

                    self.Adv_loss_meter.add(loss_g_d.item())

                    loss_G = loss_G + loss_g_d

                    loss_G.backward()
                    self.G_optimizer.step()
                    self.G_error_meter.add(loss_G.item())

                if self.visual and ii % 20 == 0:
                    vis.plot('errord', self.D_error_meter.value()[0])
                    #vis.plot('errorg', self.G_error_meter.value()[0])
                    vis.plot('errorg',
                             np.array([
                                 self.Adv_loss_meter.value()[0],
                                 self.Alpha_loss_meter.value()[0],
                                 self.Com_loss_meter.value()[0]
                             ]),
                             legend=['adv_loss', 'alpha_loss', 'com_loss'])

                    vis.images(tri_img.numpy() * 0.5 + 0.5, win='tri_map')
                    vis.images(real_img.cpu().numpy() * 0.5 + 0.5,
                               win='relate_real_input')
                    vis.images(real_alpha.cpu().numpy() * 0.5 + 0.5,
                               win='relate_real_alpha')
                    vis.images(fake_alpha.detach().cpu().numpy(),
                               win='fake_alpha')
                    if self.com_loss:
                        vis.images(fake_img.detach().cpu().numpy() * 0.5 + 0.5,
                                   win='fake_img')
            self.G_error_meter.reset()
            self.D_error_meter.reset()

            self.Alpha_loss_meter.reset()
            self.Com_loss_meter.reset()
            self.Adv_loss_meter.reset()
            if epoch % 5 == 0:
                t.save(self.D.state_dict(),
                       self.save_dir + '/netD' + '/netD_%s.pth' % epoch)
                t.save(self.G.state_dict(),
                       self.save_dir + '/netG' + '/netG_%s.pth' % epoch)

        return
Example #22
    global_step = 1
    loss_metrics = AverageValueMeter()

    for epoch in range(EPOCH):
        epoch_loss = 0
        for step, (x, y) in tqdm(enumerate(train_loader)):
            output = net(x)
            train_loss = loss_func(output, y)
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
            global_step = global_step + 1
            epoch_loss += train_loss.item()
            loss_metrics.add(train_loss.item())
        print("[epcho {}]:loss {}".format(epoch, loss_metrics.value()[0]))
        loss_metrics.reset()
        scheduler.step()
    test_loader = dataprocessing.getdataloader(mode=False)
    test_loss = 0
    global_step = 0
    loss_metrics.reset()
    for step, (x, y) in tqdm(enumerate(test_loader)):
        print(epoch, " global step ", global_step)
        output = net(x)
        test_batch_loss = loss_func(output, y)
        # evaluation only: no backward pass or optimizer step on the test set
        test_loss += test_batch_loss.item()
        global_step += 1
        loss_metrics.add(test_batch_loss.item())
Example #23
class AlphaGAN(object):
    def __init__(self, args):
        self.epoch = args.epoch
        self.batch_size = args.batch_size
        self.save_model = args.save_model
        self.save_dir = args.save_dir
        self.gpu_mode = args.gpu_mode
        self.device = args.device
        self.lrG = args.lrG
        self.lrD = args.lrD
        self.fine_tune = args.fine_tune
        self.visual = args.visual
        self.env = args.env

        if self.fine_tune:
            self.model_G = args.model
            self.model_D = args.model.replace('netG', 'netD')

        if len(self.device.split(',')) > 1:
            self.sync_bn = True
        else:
            self.sync_bn = False

        # network init
        netG = NetG(self.sync_bn)

        netD = NLayerDiscriminator(input_nc=4, n_layers=2, norm_layer=SynchronizedBatchNorm2d)

        if self.gpu_mode:
            self.G = nn.DataParallel(netG).cuda()
            self.D = nn.DataParallel(netD).cuda()
            self.G_criterion = AlphaLoss().cuda()
            self.D_criterion = t.nn.MSELoss().cuda()
        else:
            self.G = netG
            self.D = netD
            self.G_criterion = AlphaLoss()
            self.D_criterion = t.nn.MSELoss()

        if self.fine_tune:
            self.G.load_state_dict(t.load(self.model_G, map_location=t.device('cpu')))

        self.G_optimizer = t.optim.Adam(self.G.parameters(), lr=self.lrG, weight_decay=0.0005)
        # self.G_optimizer_aspp = t.optim.Adam(self.G.module.aspp.parameters(), lr=1e-4, weight_decay=0.0005)
        # self.G_optimizer_decoder = t.optim.Adam(self.G.module.decoder.parameters(), lr=1e-4, weight_decay=0.0005)
        self.D_optimizer = t.optim.Adam(self.D.parameters(), lr=self.lrD, weight_decay=0.0005)

        self.G_error_meter = AverageValueMeter()
        self.Alpha_loss_meter = AverageValueMeter()
        self.Com_loss_meter = AverageValueMeter()
        self.Adv_loss_meter = AverageValueMeter()
        self.D_error_meter = AverageValueMeter()

        self.SAD_meter = AverageValueMeter()
        self.MSE_meter = AverageValueMeter()

    def train(self, dataset):

        print('---------netG------------')
        print(self.G)
        print('---------netD------------')
        print(self.D)

        if self.visual:
            vis = Visualizer(self.env)

        for epoch in range(1, self.epoch):

            self.adjust_learning_rate(epoch)

            self.G.train()

            for ii, data in enumerate(dataset):
                t.cuda.empty_cache()

                real_img = data['I']
                tri_img = data['T']
                bg_img = data['B']
                fg_img = data['F']

                # input to the G
                input_img = t.cat([real_img, tri_img], dim=1).cuda()

                # real_alpha
                real_alpha = data['A'].cuda()

                #####################################
                # train G
                #####################################
                self.set_requires_grad([self.D], False)
                self.G_optimizer.zero_grad()

                real_img_g = input_img[:, 0:3, :, :]
                tri_img_g = input_img[:, 3:4, :, :]

                # tri_img_original = tri_img_g * 0.5 + 0.5

                fake_alpha = self.G(input_img)

                wi = t.zeros(tri_img_g.shape)
                wi[(tri_img_g * 255) == 128] = 1.
                t_wi = wi.cuda()

                unknown_size = t_wi.sum()

                fake_alpha = (1 - t_wi) * tri_img_g + t_wi * fake_alpha

                # alpha loss
                loss_g_alpha = self.G_criterion(fake_alpha, real_alpha, unknown_size)
                self.Alpha_loss_meter.add(loss_g_alpha.item())

                # compositional loss
                comp = fake_alpha * fg_img.cuda() + (1. - fake_alpha) * bg_img.cuda()
                loss_g_com = self.G_criterion(comp, real_img_g, unknown_size) / 3.
                self.Com_loss_meter.add(loss_g_com.item())

                '''
                vis.images(real_img.numpy() * 0.5 + 0.5, win='real_image', opts=dict(title='real_image'))
                vis.images(bg_img.numpy() * 0.5 + 0.5, win='bg_image', opts=dict(title='bg_image'))
                vis.images(fg_img.numpy() * 0.5 + 0.5, win='fg_image', opts=dict(title='fg_image'))
                vis.images(tri_img.numpy() * 0.5 + 0.5, win='trimap', opts=dict(title='trimap'))
                vis.images(real_alpha.detach().cpu().numpy() * 0.5 + 0.5, win='real_alpha', opts=dict(title='real_alpha'))
                vis.images(fake_alpha.detach().cpu().numpy(), win='fake_alpha', opts=dict(title='fake_alpha'))
                '''

                # trick D
                input_d = t.cat([comp, tri_img_g], dim=1)
                fake_d = self.D(input_d)

                target_fake = t.tensor(1.0).expand_as(fake_d).cuda()
                loss_g_d = self.D_criterion(fake_d, target_fake)

                self.Adv_loss_meter.add(loss_g_d.item())

                loss_G = 0.5 * loss_g_alpha + 0.5 * loss_g_com + 0.01 * loss_g_d

                loss_G.backward(retain_graph=True)
                self.G_optimizer.step()
                self.G_error_meter.add(loss_G.item())


                #########################################
                # train D
                #########################################
                self.set_requires_grad([self.D], True)
                self.D_optimizer.zero_grad()

                # real [real_img, tri]
                real_d = self.D(input_img)

                target_real_label = t.tensor(1.0)
                target_real = target_real_label.expand_as(real_d).cuda()

                loss_d_real = self.D_criterion(real_d, target_real)

                # fake [fake_img, tri]
                fake_d = self.D(input_d)
                target_fake_label = t.tensor(0.0)

                target_fake = target_fake_label.expand_as(fake_d).cuda()
                loss_d_fake = self.D_criterion(fake_d, target_fake)

                loss_D = 0.5 * (loss_d_real + loss_d_fake)
                loss_D.backward()
                self.D_optimizer.step()
                self.D_error_meter.add(loss_D.item())

                if self.visual:
                    vis.plot('errord', self.D_error_meter.value()[0])
                    vis.plot('errorg', np.array([self.Alpha_loss_meter.value()[0],
                                                 self.Com_loss_meter.value()[0]]),
                             legend=['alpha_loss', 'com_loss'])
                    vis.plot('errorg_d', self.Adv_loss_meter.value()[0])

                self.G_error_meter.reset()
                self.D_error_meter.reset()

                self.Alpha_loss_meter.reset()
                self.Com_loss_meter.reset()
                self.Adv_loss_meter.reset()

            ##############################
            # test
            ##############################

            self.G.eval()
            tester = Tester(net_G=self.G,
                            test_root='/home/zzl/dataset/Combined_Dataset/Test_set/Adobe-licensed_images')
            test_result = tester.test(vis)
            print('sad : {0}, mse : {1}'.format(test_result['sad'], test_result['mse']))
            self.SAD_meter.add(test_result['sad'])
            self.MSE_meter.add(test_result['mse'])

            vis.plot('test_result', np.array([self.SAD_meter.value()[0], self.MSE_meter.value()[0]]),
                     legend=['SAD', 'MSE'])
            if self.save_model:
                t.save(self.D.state_dict(), self.save_dir + '/netD' + '/netD_%s.pth' % epoch)
                t.save(self.G.state_dict(), self.save_dir + '/netG' + '/netG_%s.pth' % epoch)
            self.SAD_meter.reset()
            self.MSE_meter.reset()

        return

    def adjust_learning_rate(self, epoch):
        if epoch % 10 == 0:
            print('reduce learning rate')
            self.lrG = self.lrG / 10
            self.lrD = self.lrD / 10

            for param_group in self.G_optimizer.param_groups:
                param_group['lr'] = self.lrG

            for param_group in self.D_optimizer.param_groups:
                param_group['lr'] = self.lrD


    def set_requires_grad(self, nets, requires_grad=False):
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
Example #24
def train(**kwargs):
    # opt = Config()
    for k, v in kwargs.items():
        setattr(opt, k, v)
    vis = Visualizer(opt.env, opt.port)
    device = t.device('cuda') if opt.use_gpu else t.device('cpu')
    lr = opt.lr

    # network configuration
    featurenet = FeatureNet(4, 5)
    if opt.model_path:
        featurenet.load_state_dict(
            t.load(opt.model_path, map_location=lambda _s, _: _s))
    featurenet.to(device)

    # load the data
    data_set = dataset.FeatureDataset(root=opt.data_root,
                                      train=True,
                                      test=False)
    dataloader = DataLoader(data_set,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers)
    val_dataset = dataset.FeatureDataset(root=opt.data_root,
                                         train=False,
                                         test=False)
    val_dataloader = DataLoader(val_dataset,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers)
    # define the optimizer and the loss function
    optimizer = t.optim.SGD(featurenet.parameters(), lr)
    criterion = t.nn.CrossEntropyLoss().to(device)

    # track the key metrics
    loss_meter = AverageValueMeter()

    # start training
    for epoch in range(opt.max_epoch):
        loss_meter.reset()
        for ii, (data, label) in enumerate(dataloader):
            feature = data.to(device)
            target = label.to(device)

            optimizer.zero_grad()
            prob = featurenet(feature)
            # print(prob)
            # print(target)
            loss = criterion(prob, target)
            loss.backward()
            optimizer.step()
            loss_meter.add(loss.item())

            if (ii + 1) % opt.plot_every == 0:
                vis.plot('train_loss', loss_meter.value()[0])
                if os.path.exists(opt.debug_file):
                    import ipdb
                    ipdb.set_trace()
        t.save(
            featurenet.state_dict(),
            'checkpoints/{epoch}_{time}_{loss}.pth'.format(
                epoch=epoch,
                time=time.strftime('%m%d_%H_%M_%S'),
                loss=loss_meter.value()[0]))

        # validation and visualization
        accu, loss = val(featurenet, val_dataloader, criterion)
        featurenet.train()
        vis.plot('val_loss', loss)
        vis.log('epoch: {epoch}, loss: {loss}, accu: {accu}'.format(
            epoch=epoch, loss=loss, accu=accu))

        lr = lr * 0.9
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
Example #25
def train(**kwargs):
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    # choosing device for training
    if opt.gpu:
        device = torch.device("cuda")
        print('using GPU')
    else:
        device = torch.device('cpu')
        print('using CPU')

    # data preprocessing
    transforms = tv.transforms.Compose([
        # 3*96*96
        tv.transforms.Resize(opt.img_size),  # resize images to img_size * img_size
        tv.transforms.CenterCrop(opt.img_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(root=opt.data_path, transform=transforms)

    dataloader = DataLoader(
        dataset,  # loading dataset
        batch_size=opt.batch_size,  # setting batch size
        shuffle=True,  # choosing if shuffle or not
        num_workers=opt.num_workers,  # using multiple threads for processing
        drop_last=True  # drop the last batch if it is smaller than batch_size
    )

    # initialize network
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage

    # torch.load for loading models
    if opt.netg_path:
        netg.load_state_dict(
            torch.load(f=opt.netg_path, map_location=map_location))
    if opt.netd_path:
        netd.load_state_dict(
            torch.load(f=opt.netd_path, map_location=map_location))

    # move models to device
    netd.to(device)
    netg.to(device)

    # Adam optimizer
    optimize_g = torch.optim.Adam(netg.parameters(),
                                  lr=opt.lr1,
                                  betas=(opt.beta1, 0.999))
    optimize_d = torch.optim.Adam(netd.parameters(),
                                  lr=opt.lr2,
                                  betas=(opt.beta1, 0.999))

    # BCELoss: -w * (y*log(x) + (1 - y)*log(1 - x))
    # y: real label, x: discriminator score after sigmoid (1: real, 0: fake)
    criterions = nn.BCELoss().to(device)

    # define labels
    true_labels = torch.ones(opt.batch_size).to(device)
    fake_labels = torch.zeros(opt.batch_size).to(device)

    # noise sampled from the standard normal N(0, 1); dim = opt.nz, size = opt.batch_size
    noises = torch.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    # for generating images when saving models
    fix_noises = torch.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()
    write = SummaryWriter(log_dir=opt.virs, comment='loss')

    # training
    for epoch in range(opt.max_epoch):
        for ii_, (img, _) in tqdm((enumerate(dataloader))):
            real_img = img.to(device)

            # begin training
            # train discriminator for every d_every batches
            if ii_ % opt.d_every == 0:
                # clear optimizer gradient
                optimize_d.zero_grad()

                output = netd(real_img)
                error_d_real = criterions(output, true_labels)
                error_d_real.backward()

                # generate fake image
                noises = noises.detach()
                # generate fake images data using noises
                fake_image = netg(noises).detach()
                # discriminator discriminate fake images
                output = netd(fake_image)
                error_d_fake = criterions(output, fake_labels)
                error_d_fake.backward()

                optimize_d.step()

                error_d = error_d_fake + error_d_real
                errord_meter.add(error_d.item())

            # train generator for every g_every batches
            if ii_ % opt.g_every == 0:
                optimize_g.zero_grad()
                noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))
                fake_image = netg(noises)
                output = netd(fake_image)
                error_g = criterions(output, true_labels)
                error_g.backward()

                optimize_g.step()

                errorg_meter.add(error_g.item())

        # draw graph of loss
        if ii_ % 5 == 0:
            write.add_scalar("Discriminator_loss", errord_meter.value()[0])
            write.add_scalar("Generator_loss", errorg_meter.value()[0])

        # saving models for save_every batches
        if (epoch + 1) % opt.save_every == 0:
            fix_fake_image = netg(fix_noises)
            tv.utils.save_image(fix_fake_image.data[:64],
                                "%s/%s.png" % (opt.save_path, epoch),
                                normalize=True)

            torch.save(netd.state_dict(),
                       'imgs3/' + 'netd_{0}.pth'.format(epoch))
            torch.save(netg.state_dict(),
                       'imgs3/' + 'netg_{0}.pth'.format(epoch))
            errord_meter.reset()
            errorg_meter.reset()

    write.close()
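The SummaryWriter used above comes from torch.utils.tensorboard, and add_scalar also accepts a global_step so the curves are indexed by iteration or epoch, e.g.:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/gan_demo')      # log_dir is illustrative
writer.add_scalar('Discriminator_loss', 0.42, 100)   # value 0.42 at global step 100
writer.close()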
Example #26
def train(**kwargs):
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device('cuda') if opt.gpu else t.device('cpu')
    # if opt.vis:
    #     from visualize import Visualizer
    #     vis = Visualizer(opt.env)

    # data
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(root=opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True)

    # networks
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)

    # define the optimizers and the loss
    optimizer_g = t.optim.Adam(netg.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(),
                               opt.lr2,
                               betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(device)

    # Real images get label 1, fake images get label 0
    # noises is the input to the generator
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if ii % opt.d_every == 0:
                # Train the discriminator
                optimizer_d.zero_grad()
                ## push the discriminator to classify real images as real
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                ## push the discriminator to classify fake images as fake
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()  # generate fake images from the noise
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # Train the generator
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                ## Visualization
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 +
                           0.5,
                           win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch + 1) % opt.save_every == 0:
            # Save the models and sample images
            # regenerate the fixed-noise samples so this also works when opt.vis is False
            fix_fake_imgs = netg(fix_noises)
            tv.utils.save_image(fix_fake_imgs.data[:64],
                                '%s/%s.png' % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
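Exemplo n.º 26 and the similar examples below read every hyper-parameter from a module-level opt object that is not shown. A minimal config sketch consistent with the attributes accessed above is given here; the field names come from the example, while all values are illustrative defaults, not the author's settings.

class Config:
    data_path = 'data/'       # root folder for tv.datasets.ImageFolder
    image_size = 96           # resize / crop size
    batch_size = 256
    num_workers = 4
    nz = 100                  # noise vector dimensionality
    lr1 = 2e-4                # generator learning rate
    lr2 = 2e-4                # discriminator learning rate
    beta1 = 0.5               # Adam beta1
    max_epoch = 200
    d_every = 1               # train D every batch
    g_every = 5               # train G every 5 batches
    save_every = 10           # checkpoint every 10 epochs
    plot_every = 20           # visualize every 20 batches
    gpu = True
    vis = False
    env = 'GAN'               # visdom environment
    debug_file = '/tmp/debuggan'
    save_path = 'imgs'
    netd_path = None          # resume discriminator weights from this path if set
    netg_path = None          # resume generator weights from this path if set

opt = Config()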
Exemplo n.º 27
0
def train(**kwargs):
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device("cuda") if opt.gpu else t.device("cpu")

    # Data preprocessing; outputs are normalized to [-1, 1]
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True)

    # Networks
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)

    # Define the optimizers and the loss
    optimizer_g = t.optim.Adam(netg.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(),
                               opt.lr2,
                               betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss()

    # Real images get label 1, fake images get label 0; noise is the generator input
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    # meters that track the running mean and std of the losses
    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if ii % opt.d_every == 0:
                # Train the discriminator
                optimizer_d.zero_grad()
                ## push the discriminator to classify real images as real
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                ## push the discriminator to classify fake images as fake
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                # detach so no gradients flow back into G, which speeds up training
                fake_img = netg(noises).detach()
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # Train the generator
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                # try to get the fake images classified as real (label 1)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            # Visualization (omitted in this example)

        # Save the models and sample images
        if (epoch + 1) % opt.save_every == 0:
            fix_fake_imgs = netg(fix_noises)
            tv.utils.save_image(fix_fake_imgs.data[:64],
                                "%s%s.png" % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(), r"./checkpoints/netd_%s.pth" % epoch)
            t.save(netg.state_dict(), r"./checkpoints/netg_%s.pth" % epoch)
            errord_meter.reset()
            errorg_meter.reset()
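Exemplo n.º 26 and 27 only show the training loop; the NetG / NetD architectures are defined elsewhere. A minimal DCGAN-style pair that matches how they are called above (NetG maps a (N, nz, 1, 1) noise tensor to an image in [-1, 1], NetD returns one probability per image for BCELoss) is sketched below. The layer widths and the 96x96 output resolution are assumptions, not the author's exact definition.

import torch.nn as nn

class NetG(nn.Module):
    def __init__(self, opt):
        super().__init__()
        ngf = 64
        self.main = nn.Sequential(
            nn.ConvTranspose2d(opt.nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8), nn.ReLU(True),          # 4x4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4), nn.ReLU(True),          # 8x8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2), nn.ReLU(True),          # 16x16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf), nn.ReLU(True),              # 32x32
            nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
            nn.Tanh()                                        # 96x96 output in [-1, 1]
        )

    def forward(self, z):
        return self.main(z)

class NetD(nn.Module):
    def __init__(self, opt):
        super().__init__()
        ndf = 64
        self.main = nn.Sequential(
            nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),                 # 32x32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x).view(-1)                         # (N,) probabilities for BCELoss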
Exemplo n.º 28
0
class AlphaGAN(object):
    def __init__(self, args):
        self.epoch = args.epoch
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.gpu_mode = args.gpu_mode
        self.device = args.device
        self.lrG = args.lrG
        self.lrD = args.lrD
        self.com_loss = args.com_loss
        self.fine_tune = args.fine_tune
        self.visual = args.visual
        self.env = args.env
        self.d_every = args.d_every
        self.g_every = args.g_every

        if self.fine_tune:
            self.model_G = args.model
            self.model_D = args.model.replace('netG', 'netD')

        # network init
        self.G = NetG()
        if self.com_loss:
            self.D = NLayerDiscriminator(input_nc=4)
        else:
            self.D = NLayerDiscriminator(input_nc=2)

        print(self.G)
        print(self.D)

        if self.fine_tune:
            self.G.load_state_dict(t.load(self.model_G))
            self.D.load_state_dict(t.load(self.model_D))

        self.G_optimizer = t.optim.Adam(self.G.parameters(), lr=self.lrG)
        self.D_optimizer = t.optim.Adam(self.D.parameters(), lr=self.lrD)
        if self.gpu_mode:
            self.G.to(self.device)
            self.D.to(self.device)
            self.G_criterion = t.nn.SmoothL1Loss().to(self.device)
            self.D_criterion = t.nn.MSELoss().to(self.device)

        self.G_error_meter = AverageValueMeter()
        self.Alpha_loss_meter = AverageValueMeter()
        self.Com_loss_meter = AverageValueMeter()
        self.Adv_loss_meter = AverageValueMeter()
        self.D_error_meter = AverageValueMeter()

    def train(self, dataset):
        if self.visual:
            vis = Visualizer(self.env)

        for epoch in range(self.epoch):
            for ii, data in tqdm.tqdm(enumerate(dataset)):
                real_img = data['I']
                tri_img = data['T']

                if self.com_loss:
                    bg_img = data['B'].to(self.device)
                    fg_img = data['F'].to(self.device)

                # input to the G
                input_img = t.cat([real_img, tri_img], dim=1).to(self.device)

                # real_alpha
                real_alpha = data['A'].to(self.device)

                # vis.images(real_img.numpy()*0.5 + 0.5, win='input_real_img')
                # vis.images(real_alpha.cpu().numpy()*0.5 + 0.5, win='real_alpha')
                # vis.images(tri_img.numpy()*0.5 + 0.5, win='tri_map')

                # train D
                if ii % self.d_every == 0:
                    self.D_optimizer.zero_grad()

                    # real_img_d = input_img[:, 0:3, :, :]
                    tri_img_d = input_img[:, 3:4, :, :]

                    # the real alpha is passed to the discriminator
                    if self.com_loss:
                        real_d = self.D(input_img)
                    else:
                        real_d = self.D(t.cat([real_alpha, tri_img_d], dim=1))

                    target_real_label = t.tensor(1.0)
                    target_real = target_real_label.expand_as(real_d).to(self.device)

                    loss_d_real = self.D_criterion(real_d, target_real)
                    #loss_d_real.backward()

                    # the generator produces fake_alpha, which is passed to the discriminator
                    fake_alpha = self.G(input_img)
                    if self.com_loss:
                        fake_img = fake_alpha*fg_img + (1 - fake_alpha) * bg_img
                        fake_d = self.D(t.cat([fake_img, tri_img_d], dim=1))
                    else:
                        fake_d = self.D(t.cat([fake_alpha, tri_img_d], dim=1))
                    target_fake_label = t.tensor(0.0)

                    target_fake = target_fake_label.expand_as(fake_d).to(self.device)

                    loss_d_fake = self.D_criterion(fake_d, target_fake)

                    loss_D = loss_d_real + loss_d_fake
                    loss_D.backward()
                    self.D_optimizer.step()
                    self.D_error_meter.add(loss_D.item())

                # train G
                if ii % self.g_every == 0:
                    self.G_optimizer.zero_grad()

                    real_img_g = input_img[:, 0:3, :, :]
                    tri_img_g = input_img[:, 3:4, :, :]

                    fake_alpha = self.G(input_img)
                    # smooth L1 loss between fake_alpha and real_alpha
                    loss_g_alpha = self.G_criterion(fake_alpha, real_alpha)
                    loss_G = loss_g_alpha
                    self.Alpha_loss_meter.add(loss_g_alpha.item())

                    if self.com_loss:
                        fake_img = fake_alpha * fg_img + (1 - fake_alpha) * bg_img
                        loss_g_cmp = self.G_criterion(fake_img, real_img_g)

                        # try to fool the discriminator
                        fake_d = self.D(t.cat([fake_img, tri_img_g], dim=1))
                        self.Com_loss_meter.add(loss_g_cmp.item())
                        loss_G = loss_G + loss_g_cmp

                    else:
                        fake_d = self.D(t.cat([fake_alpha, tri_img_g], dim=1))
                    target_fake = t.tensor(1.0).expand_as(fake_d).to(self.device)
                    loss_g_d = self.D_criterion(fake_d, target_fake)

                    self.Adv_loss_meter.add(loss_g_d.item())

                    loss_G = loss_G + loss_g_d

                    loss_G.backward()
                    self.G_optimizer.step()
                    self.G_error_meter.add(loss_G.item())

                if self.visual and ii % 20 == 0:
                    vis.plot('errord', self.D_error_meter.value()[0])
                    #vis.plot('errorg', self.G_error_meter.value()[0])
                    vis.plot('errorg', np.array([self.Adv_loss_meter.value()[0], self.Alpha_loss_meter.value()[0],
                                                 self.Com_loss_meter.value()[0]]), legend=['adv_loss', 'alpha_loss',
                                                                                           'com_loss'])

                    vis.images(tri_img.numpy()*0.5 + 0.5, win='tri_map')
                    vis.images(real_img.cpu().numpy() * 0.5 + 0.5, win='relate_real_input')
                    vis.images(real_alpha.cpu().numpy() * 0.5 + 0.5, win='relate_real_alpha')
                    vis.images(fake_alpha.detach().cpu().numpy(), win='fake_alpha')
                    if self.com_loss:
                        vis.images(fake_img.detach().cpu().numpy()*0.5 + 0.5, win='fake_img')
            self.G_error_meter.reset()
            self.D_error_meter.reset()

            self.Alpha_loss_meter.reset()
            self.Com_loss_meter.reset()
            self.Adv_loss_meter.reset()
            if epoch % 5 == 0:
                t.save(self.D.state_dict(), self.save_dir + '/netD' + '/netD_%s.pth' % epoch)
                t.save(self.G.state_dict(), self.save_dir + '/netG' + '/netG_%s.pth' % epoch)

        return
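Every example in this listing funnels its losses through torchnet's AverageValueMeter. For readers without torchnet at hand, a simplified re-implementation of the behaviour relied on above (add() accumulates scalars, value() returns a (mean, std) pair, reset() clears the statistics) is sketched here; it is not torchnet's actual code.

import math

class AverageValueMeter:
    """Simplified sketch of torchnet.meter.AverageValueMeter."""

    def __init__(self):
        self.reset()

    def reset(self):
        # clear the accumulated statistics
        self.n = 0
        self.sum = 0.0
        self.sum_sq = 0.0

    def add(self, value, n=1):
        # accumulate a scalar observation (optionally weighted by n)
        self.sum += value * n
        self.sum_sq += (value ** 2) * n
        self.n += n

    def value(self):
        # return (mean, std) of everything added since the last reset()
        if self.n == 0:
            return float('nan'), float('nan')
        mean = self.sum / self.n
        if self.n == 1:
            return mean, float('inf')
        var = (self.sum_sq - self.n * mean ** 2) / (self.n - 1)   # unbiased sample variance
        return mean, math.sqrt(max(var, 0.0))

Calls such as errord_meter.value()[0] in the examples then read back the mean loss accumulated since the last reset().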
Exemplo n.º 29
0
class AlphaGAN(object):
    def __init__(self, args):
        self.epoch = args.epoch
        self.warmup_step = args.warmup_step
        self.batch_size = args.batch_size
        self.save_model = args.save_model
        self.save_dir = args.save_dir
        self.gpu_mode = args.gpu_mode
        self.device = args.device
        self.lrG = args.lrG
        self.lrD = args.lrD
        self.fine_tune = args.fine_tune
        self.visual = args.visual
        self.env = args.env

        self.best_sad = 1000

        if self.fine_tune:
            self.model_G = args.model
            self.model_D = args.model.replace('netG', 'netD')

        if len(self.device.split(',')) > 1:
            self.sync_bn = False
        else:
            self.sync_bn = False

        # network init
        netG = NetG(self.sync_bn)

        netD = NLayerDiscriminator(input_nc=4,
                                   n_layers=2,
                                   norm_layer=SynchronizedBatchNorm2d)

        if self.gpu_mode:
            self.G = netG.cuda()
            self.D = netD.cuda()
            self.G_criterion = AlphaLoss().cuda()
            self.D_criterion = t.nn.MSELoss().cuda()
        else:
            self.G = netG
            self.D = netD
            self.G_criterion = AlphaLoss()
            self.D_criterion = t.nn.MSELoss()

        if self.fine_tune:
            self.G.load_state_dict(
                t.load(self.model_G, map_location=t.device('cpu')))

        self.G_optimizer = t.optim.Adam(self.G.parameters(),
                                        lr=self.lrG,
                                        weight_decay=0.0005)
        # self.G_optimizer_aspp = t.optim.Adam(self.G.module.aspp.parameters(), lr=1e-4, weight_decay=0.0005)
        # self.G_optimizer_decoder = t.optim.Adam(self.G.module.decoder.parameters(), lr=1e-4, weight_decay=0.0005)
        self.D_optimizer = t.optim.Adam(self.D.parameters(),
                                        lr=self.lrD,
                                        weight_decay=0.0005)

        self.G_scheduler = lr_scheduler.CosineAnnealingLR(self.G_optimizer,
                                                          T_max=self.epoch -
                                                          self.warmup_step)
        self.D_scheduler = lr_scheduler.CosineAnnealingLR(self.D_optimizer,
                                                          T_max=self.epoch -
                                                          self.warmup_step)

        self.avg_G_loss = AverageValueMeter()
        self.avg_D_loss = AverageValueMeter()

        self.SAD_meter = AverageValueMeter()
        self.MSE_meter = AverageValueMeter()

    def train(self, dataset):

        print('---------netG------------')
        print(self.G)
        print('---------netD------------')
        print(self.D)

        writer = SummaryWriter('tensorboardlog/AlphaGAN_bs_1')
        niter = 1

        for epoch in range(1, self.epoch + 1):

            cur_lr = self.adjust_learning_rate(epoch)
            writer.add_scalar('Train/lr', cur_lr, epoch)

            self.G.train()

            for ii, data in enumerate(dataset):
                #t.cuda.empty_cache()

                real_img = data['I']
                tri_img = data['T']
                bg_img = data['B']
                fg_img = data['F']

                # input to the G
                input_img = t.cat([real_img, tri_img], dim=1).cuda()

                # real_alpha
                real_alpha = data['A'].cuda()

                #####################################
                # train G
                #####################################
                self.set_requires_grad([self.D], False)
                self.G_optimizer.zero_grad()

                real_img_g = input_img[:, 0:3, :, :]
                tri_img_g = input_img[:, 3:4, :, :]

                # tri_img_original = tri_img_g * 0.5 + 0.5

                fake_alpha = self.G(input_img)

                wi = t.zeros(tri_img_g.shape)
                wi[(tri_img_g * 255) == 128] = 1.
                t_wi = wi.cuda()

                unknown_size = t_wi.sum()

                fake_alpha = (1 - t_wi) * tri_img_g + t_wi * fake_alpha

                # alpha loss
                loss_g_alpha = self.G_criterion(fake_alpha, real_alpha,
                                                unknown_size)

                # compositional loss
                comp = fake_alpha * fg_img.cuda() + (
                    1. - fake_alpha) * bg_img.cuda()
                loss_g_com = self.G_criterion(comp, real_img_g,
                                              unknown_size) / 3.
                # self.Com_loss_meter.add(loss_g_com.item())
                '''
                vis.images(real_img.numpy() * 0.5 + 0.5, win='real_image', opts=dict(title='real_image'))
                vis.images(bg_img.numpy() * 0.5 + 0.5, win='bg_image', opts=dict(title='bg_image'))
                vis.images(fg_img.numpy() * 0.5 + 0.5, win='fg_image', opts=dict(title='fg_image'))
                vis.images(tri_img.numpy() * 0.5 + 0.5, win='trimap', opts=dict(title='trimap'))
                vis.images(real_alpha.detach().cpu().numpy() * 0.5 + 0.5, win='real_alpha', opts=dict(title='real_alpha'))
                vis.images(fake_alpha.detach().cpu().numpy(), win='fake_alpha', opts=dict(title='fake_alpha'))
                '''

                # trick D
                input_d = t.cat([comp, tri_img_g], dim=1)
                fake_d = self.D(input_d)

                target_fake = t.tensor(1.0).expand_as(fake_d).cuda()
                loss_g_d = self.D_criterion(fake_d, target_fake)

                loss_G = 0.5 * loss_g_alpha + 0.5 * loss_g_com + 0.001 * loss_g_d

                self.avg_G_loss.add(loss_G.item())

                loss_G.backward(retain_graph=True)
                self.G_optimizer.step()
                writer.add_scalar('Train/CompLoss', loss_g_com.item(), niter)
                writer.add_scalar('Train/AlphaLoss', loss_g_alpha.item(),
                                  niter)
                writer.add_scalar('Train/AdvLoss', loss_g_d.item(), niter)
                writer.add_scalar('Train/lossG', loss_G.item(), niter)

                #########################################
                # train D
                #########################################
                self.set_requires_grad([self.D], True)
                self.D_optimizer.zero_grad()

                # real [real_img, tri]
                real_d = self.D(input_img)

                target_real_label = t.tensor(1.0)
                target_real = target_real_label.expand_as(real_d).cuda()

                loss_d_real = self.D_criterion(real_d, target_real)

                # fake [fake_img, tri]
                fake_d = self.D(input_d)
                target_fake_label = t.tensor(0.0)

                target_fake = target_fake_label.expand_as(fake_d).cuda()
                loss_d_fake = self.D_criterion(fake_d, target_fake)

                loss_D = 0.5 * (loss_d_real + loss_d_fake)
                loss_D.backward()
                self.D_optimizer.step()
                self.avg_D_loss.add(loss_D.item())
                writer.add_scalar('Train/D_AdvLoss', loss_D.item(), niter)
                niter += 1
            writer.add_scalar('Train/G_avg_loss',
                              self.avg_G_loss.value()[0], epoch)
            writer.add_scalar('Train/D_avg_loss',
                              self.avg_D_loss.value()[0], epoch)
            self.avg_G_loss.reset()
            self.avg_D_loss.reset()

            ##############################
            # test
            ##############################

            if epoch % 1 == 0:
                print('-----test-------')
                tester = Tester(
                    net_G=self.G.eval(),
                    test_root=
                    '/data1/zzl/dataset/Combined_Dataset/Test_set/Adobe-licensed_images'
                )
                test_result = tester.test()
                if test_result['sad'] < self.best_sad:
                    self.best_sad = test_result['sad']
                    t.save(self.G.state_dict(),
                           self.save_dir + '/netG' + '/netG_best_sad.pth')
                print('sad : {0}, mse : {1}'.format(test_result['sad'],
                                                    test_result['mse']))
                writer.add_scalars('Test/metric', test_result, epoch)

        return

    def adjust_learning_rate(self, step):
        if step < self.warmup_step:
            cur_lr = self.warmup_lr(self.lrG, step, self.warmup_step)
            for param_group in self.G_optimizer.param_groups:
                param_group['lr'] = cur_lr
            for param_group in self.D_optimizer.param_groups:
                param_group['lr'] = cur_lr
        else:
            self.G_scheduler.step()
            self.D_scheduler.step()

            cur_lr = self.G_scheduler.get_lr()[0]

        return cur_lr

    def set_requires_grad(self, nets, requires_grad=False):
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def warmup_lr(self, init_lr, step, iter_num):
        return step / iter_num * init_lr
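Exemplo n.º 29 calls its generator criterion as G_criterion(pred, target, unknown_size) but the AlphaLoss class itself is not shown. One plausible implementation consistent with that call signature, a Charbonnier-style L1 difference normalized by the number of unknown trimap pixels, is sketched below; the author's actual AlphaLoss may differ.

import torch as t
import torch.nn as nn

class AlphaLoss(nn.Module):
    """Hypothetical alpha-prediction loss matching the call
    G_criterion(pred, target, unknown_size) used above."""

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target, unknown_size):
        diff = pred - target
        # smooth absolute difference; the masking of pred with the trimap above
        # concentrates the error inside the unknown region
        loss = t.sqrt(diff * diff + self.eps)
        # average over the unknown pixels rather than the full image
        return loss.sum() / (unknown_size + 1.0)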
Exemplo n.º 30
0
            # If D loss is zero, then re-initialize netD
            if err_d.item() < 1e-5:
                netd.apply(weights_init)

            #--update_netg--    Update G network: log(D(G(x)))  + ||G(x) - x||
            netg.zero_grad()
            #out_g, _ = netd(fake)
            err_g_bce = criterion_L2(feat_true, feat_fake)  # l_adv
            err_g_l1l = criterion_L1(fake, img_st)  # l_con
            err_g_enc = criterion_L2(latent_i, latent_o)  # l_enc
            err_g = err_g_bce * config.w_bce + err_g_l1l * config.w_rec + err_g_enc * config.w_enc
            err_g.backward()
            optimizer_g.step()
            optimizer_f.step()
            errorg_meter.add(err_g.data.cpu().numpy())
            vis.plot('errorg', errorg_meter.value()[0])

            err_Latent = err_g_enc
            errorLatent_meter.add(err_Latent.data.cpu().numpy())
            vis.plot('errorLatent', errorLatent_meter.value()[0])
            #vis.images(((t.squeeze(fake[:,:,1,:,:],0).detach().cpu().numpy())), win='Fake')
            #vis.images(((t.squeeze(img_3d[:,:,1,:,:],0).detach().cpu().numpy())), win='Real')

            if epoch % config.adjust_lr == 0:
                t.save(net_st_fusion.state_dict(),
                       'cpkt_CUHK/netfusion_%s.pth' % epoch)
                t.save(netd.state_dict(), 'cpkt_CUHK/netd_%s.pth' % epoch)
                t.save(netg.state_dict(), 'cpkt_CUHK/netg_%s.pth' % epoch)
                errord_meter.reset()
                errorg_meter.reset()
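The fragment in Exemplo n.º 30 re-initializes the discriminator with netd.apply(weights_init) whenever its loss collapses. The function itself is not shown; the conventional DCGAN-style initializer usually passed to apply() in this situation, assumed here, is:

def weights_init(m):
    # standard DCGAN initialization: N(0, 0.02) for conv weights,
    # N(1, 0.02) for BatchNorm scales, zero BatchNorm bias
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)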
Exemplo n.º 31
0
def train(**kwargs):
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device=t.device('cuda') if opt.gpu else t.device('cpu')
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # Data
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True
                                         )

    # Networks
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)


    # Define the optimizers and the loss
    optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(device)

    # Real images get label 1, fake images get label 0
    # noises is the input to the generator
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()


    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if ii % opt.d_every == 0:
                # Train the discriminator
                optimizer_d.zero_grad()
                ## push the discriminator to classify real images as real
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                ## push the discriminator to classify fake images as fake
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()  # generate fake images from the noise
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # Train the generator
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                ## Visualization
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5, win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch+1) % opt.save_every == 0:
            # Save the models and sample images
            # regenerate the fixed-noise samples so this also works when opt.vis is False
            fix_fake_imgs = netg(fix_noises)
            tv.utils.save_image(fix_fake_imgs.data[:64], '%s/%s.png' % (opt.save_path, epoch), normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
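Once a training run from Exemplo n.º 26/27/31 has written its checkpoints, the saved generator can be reloaded to sample new images. A short usage sketch, reusing the opt and NetG names from the examples above; the checkpoint filename is illustrative and follows the 'checkpoints/netg_%s.pth' pattern used by those examples.

import torch as t
import torchvision as tv

device = t.device('cuda') if opt.gpu else t.device('cpu')
netg = NetG(opt).to(device)
netg.load_state_dict(t.load('checkpoints/netg_199.pth',
                            map_location=lambda storage, loc: storage))
netg.eval()

with t.no_grad():
    # draw a fresh batch of noise and generate images in [-1, 1]
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    fake_imgs = netg(noises)

# save a 64-image grid, rescaling from [-1, 1] to [0, 1]
tv.utils.save_image(fake_imgs.data[:64], 'result.png', normalize=True, range=(-1, 1))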