# Example #1 (예제 #1, score: 0)
class Solver(object):
    """Train / validate / test driver for the DSS salient-object-detection model.

    Expects loaders that yield dict batches with 'image' and 'label' keys.
    """

    def __init__(self, train_loader, val_loader, test_dataset, config):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_dataset = test_dataset
        self.config = config
        self.beta = math.sqrt(0.3)  # beta^2 = 0.3 for the max F_beta metric
        # inference: indices of the side-output maps that get averaged (see paper)
        self.select = [1, 2, 3, 6]
        self.device = torch.device('cpu')
        # ImageNet normalization statistics, shaped (3, 1, 1) for broadcasting
        self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        if self.config.cuda:
            cudnn.benchmark = True
            self.device = torch.device('cuda:0')
        if config.visdom:
            self.visual = Viz_visdom("DSS 12-6-19", 1)
        self.build_model()
        if self.config.pre_trained:
            self.net.load_state_dict(torch.load(self.config.pre_trained))
        if config.mode == 'train':
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            # test mode: load the trained model and freeze it in eval mode
            self.net.load_state_dict(torch.load(self.config.model))
            self.net.eval()
            self.test_output = open("%s/test.txt" % config.test_fold, 'w')
            self.transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

    # print the network information and parameter numbers
    def print_network(self, model, name):
        """Print `name`, the model structure and its trainable parameter count."""
        num_params = sum(p.numel() for p in model.parameters()
                         if p.requires_grad)
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    # build the network
    def build_model(self):
        """Construct the network, the training loss and the Adam optimizer."""
        self.net = build_model().to(self.device)
        if self.config.mode == 'train': self.loss = Loss().to(self.device)
        self.net.train()
        self.net.apply(weights_init)
        if self.config.load == '':
            # fresh run: initialize only the VGG backbone
            self.net.base.load_state_dict(torch.load(self.config.vgg))
        else:
            # resume a full checkpoint
            self.net.load_state_dict(torch.load(self.config.load))
        self.optimizer = Adam(self.net.parameters(), self.config.lr)
        self.print_network(self.net, 'DSS')

    # update the learning rate
    def update_lr(self, lr):
        """Set `lr` on every optimizer parameter group."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    # evaluate MAE (for test or validation phase)
    def eval_mae(self, y_pred, y):
        """Mean absolute error between prediction and ground truth."""
        return torch.abs(y_pred - y).mean()

    # TODO: write a more efficient version
    # get precisions and recalls: threshold---divided [0, 1] to num values
    def eval_pr(self, y_pred, y, num):
        """Precision/recall at `num` thresholds evenly spaced over [0, 1)."""
        prec, recall = torch.zeros(num), torch.zeros(num)
        thlist = torch.linspace(0, 1 - 1e-10, num)
        for i in range(num):
            y_temp = (y_pred >= thlist[i]).float()
            tp = (y_temp * y).sum()  # true positives at this threshold
            prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / y.sum()
        return prec, recall

    # validation: using resize image, and only evaluate the MAE metric
    def validation(self):
        """Return the mean MAE over the validation loader."""
        avg_mae = 0.0
        self.net.eval()
        with torch.no_grad():
            for i, data_batch in enumerate(self.val_loader):
                images, labels = data_batch['image'], data_batch['label']
                # .float().to(device) works on CPU and GPU alike; the old
                # torch.cuda.FloatTensor cast crashed on CPU-only machines
                images = images.float().to(self.device)
                labels = labels.float().to(self.device)
                prob_pred = self.net(images)
                # average the selected side-output maps into one saliency map
                prob_pred = torch.mean(torch.cat(
                    [prob_pred[k] for k in self.select], dim=1),
                                       dim=1,
                                       keepdim=True)
                avg_mae += self.eval_mae(prob_pred, labels).item()
                # running mean, not the raw cumulative sum the old code printed
                print("running mean MAE: %.4f" % (avg_mae / (i + 1)))
        self.net.train()
        return avg_mae / len(self.val_loader)

    # test phase: using origin image size, evaluate MAE and max F_beta metrics
    def test(self, num, use_crf=False):
        """Evaluate MAE and max F-measure on the test set, saving predicted maps."""
        if use_crf: from tools.crf_process import crf
        avg_mae, img_num = 0.0, len(self.test_dataset)
        avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)
        with torch.no_grad():
            for i, data in enumerate(self.test_dataset):
                images, labels = data['image'], data['label']
                images = images.float().to(self.device)
                # labels stay on the CPU: metrics below run on the CPU copy of
                # the prediction (the old code moved labels to CUDA and then
                # compared them against a CPU tensor, a device mismatch)
                labels = labels.float()
                shape = labels.size()[2:]  # original spatial size
                prob_pred = self.net(images)
                prob_pred = torch.mean(torch.cat(
                    [prob_pred[k] for k in self.select], dim=1),
                                       dim=1,
                                       keepdim=True)
                prob_pred = F.interpolate(prob_pred,
                                          size=shape,
                                          mode='bilinear',
                                          align_corners=True).cpu().data
                # TODO(review): hard-coded machine-specific output directory;
                # consider moving it into the config
                result_dir = 'C:/Users/Paul Vincent Nonat/Documents/Graduate Student Files/results/'
                save_image(prob_pred[0],
                           result_dir + 'result' + str(i) + '.png')
                if use_crf:
                    # the old code passed an undefined name `img`; feed the
                    # network input instead — TODO confirm crf()'s expected
                    # image format against tools.crf_process
                    prob_pred = crf(images.cpu().numpy(), prob_pred.numpy(),
                                    to_tensor=True)
                mae = self.eval_mae(prob_pred, labels)
                prec, recall = self.eval_pr(prob_pred, labels, num)
                print("[%d] mae: %.4f" % (i, mae))
                print("[%d] mae: %.4f" % (i, mae), file=self.test_output)
                avg_mae += mae.item()  # accumulate a python float, not a tensor
                avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
        avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
        # F_beta = (1 + b^2) * P * R / (b^2 * P + R), maximized over thresholds
        score = (1 + self.beta**2) * avg_prec * avg_recall / (
            self.beta**2 * avg_prec + avg_recall)
        score[score != score] = 0  # zero out NaNs from 0/0 thresholds
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()))
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()),
              file=self.test_output)

    # training phase
    def train(self, num):
        """Run the optimization loop with periodic logging/validation/checkpoints.

        Args:
            num: kept for interface compatibility (was meant for P/R curves);
                currently unused.
        """
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        best_mae = 1.0 if self.config.val else None
        for epoch in range(self.config.epoch):
            loss_epoch = 0
            for i, data_batch in enumerate(self.train_loader):
                if (i + 1) > iter_num: break  # skip the trailing partial batch
                x, y = data_batch['image'], data_batch['label']
                # device-agnostic transfer (no torch.cuda.FloatTensor, no
                # deprecated Variable wrappers)
                x = x.float().to(self.device)
                y = y.float().to(self.device)
                self.net.zero_grad()
                y_pred = self.net(x)
                loss = self.loss(y_pred, y)
                loss.backward()
                utils.clip_grad_norm_(self.net.parameters(),
                                      self.config.clip_gradient)
                self.optimizer.step()
                loss_epoch += loss.item()
                print('epoch: [%d/%d], iter: [%d/%d], loss: [%.4f]' %
                      (epoch, self.config.epoch, i, iter_num, loss.item()))
                if self.config.visdom:
                    error = OrderedDict([('loss:', loss.item())])
                    self.visual.plot_current_errors('Cross Entropy Loss',
                                                    epoch, i / iter_num, error)

            if (epoch + 1) % self.config.epoch_show == 0:
                print('epoch: [%d/%d], epoch_loss: [%.4f]' %
                      (epoch, self.config.epoch, loss_epoch / iter_num),
                      file=self.log_output)
                if self.config.visdom:
                    avg_err = OrderedDict([('avg_loss', loss_epoch / iter_num)
                                           ])
                    self.visual.plot_current_errors('Average Loss per Epoch',
                                                    epoch, i / iter_num,
                                                    avg_err, 1)
                    # show input / label / per-side-output prediction;
                    # `idx` deliberately does not shadow the batch index `i`
                    for idx in self.select:
                        img = OrderedDict([('origin' + str(epoch) + str(idx),
                                            x.cpu()[0] * self.std + self.mean),
                                           ('label' + str(epoch) + str(idx),
                                            y.cpu()[0][0]),
                                           ('pred_label' + str(epoch) + str(idx),
                                            y_pred[idx].cpu().data[0][0])])
                        self.visual.plot_current_img(img)

            if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
                mae = self.validation()
                # NOTE: the old precision/recall/F-measure plots here read
                # `prob_pred`/`labels`, names that do not exist in this scope
                # and raised NameError; validation() only returns MAE, so
                # only MAE is reported.
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' %
                      (best_mae, mae))
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' %
                      (best_mae, mae),
                      file=self.log_output)
                if self.config.visdom:
                    error = OrderedDict([('MAE:', mae)])
                    self.visual.plot_current_errors(
                        'Mean Absolute Error Graph', epoch, i / iter_num,
                        error, 2)
                if best_mae > mae:
                    best_mae = mae
                    torch.save(self.net.state_dict(),
                               '%s/models/best.pth' % self.config.save_fold)
            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(
                    self.net.state_dict(), '%s/models/epoch_%d.pth' %
                    (self.config.save_fold, epoch + 1))
        torch.save(self.net.state_dict(),
                   '%s/models/final.pth' % self.config.save_fold)
# Example #2 (예제 #2, score: 0)
class Solver(object):
    """Train / validate / test driver for the NLDF saliency network."""

    def __init__(self, train_loader, val_loader, test_loader, config):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.config = config
        # per-channel image mean rescaled to [0, 1]; broadcastable over (C, H, W)
        self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1,
                                                                  1) / 255
        # NOTE(review): the F-measure below uses self.beta**2 = 0.09;
        # saliency papers usually take beta^2 = 0.3 — confirm intended.
        self.beta = 0.3
        self.device = torch.device('cpu')
        if self.config.cuda:
            cudnn.benchmark = True
            self.device = torch.device('cuda')
        if config.visdom:
            self.visual = Viz_visdom("NLDF", 1)
        self.build_model()
        # optionally warm-start from a pre-trained checkpoint
        if self.config.pre_trained:
            self.net.load_state_dict(torch.load(self.config.pre_trained))
        if config.mode == 'train':
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            # test mode: load the trained model and freeze it in eval mode
            self.net.load_state_dict(torch.load(self.config.model))
            self.net.eval()
            self.test_output = open("%s/test.txt" % config.test_fold, 'w')

    def print_network(self, model, name):
        """Print `name`, the model structure and its total parameter count."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def build_model(self):
        """Build the network, the loss (train mode only) and the optimizer."""
        self.net = build_model()
        if self.config.mode == 'train':
            self.loss = Loss(self.config.area, self.config.boundary)
        self.net = self.net.to(self.device)
        if self.config.cuda and self.config.mode == 'train':
            self.loss = self.loss.cuda()
        self.net.train()
        self.net.apply(weights_init)
        # either initialize the VGG backbone or resume a full checkpoint
        if self.config.load == '':
            self.net.base.load_state_dict(torch.load(self.config.vgg))
        if self.config.load != '':
            self.net.load_state_dict(torch.load(self.config.load))
        self.optimizer = Adam(self.net.parameters(), self.config.lr)
        self.print_network(self.net, 'NLDF')

    def update_lr(self, lr):
        """Set the learning rate of every optimizer parameter group to `lr`."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def clip(self, y):
        """Clamp `y` element-wise into [0, 1]."""
        return torch.clamp(y, 0.0, 1.0)

    def eval_mae(self, y_pred, y):
        """Mean absolute error between prediction and ground truth."""
        return torch.abs(y_pred - y).mean()

    # TODO: write a more efficient version
    def eval_pr(self, y_pred, y, num):
        """Precision/recall at `num` thresholds evenly spaced over [0, 1)."""
        prec, recall = torch.zeros(num), torch.zeros(num)
        thlist = torch.linspace(0, 1 - 1e-10, num)
        for i in range(num):
            y_temp = (y_pred >= thlist[i]).float()
            tp = (y_temp * y).sum()  # true positives at this threshold
            prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / y.sum()
        return prec, recall

    def validation(self):
        """Return the mean MAE over the validation loader."""
        avg_mae = 0.0
        self.net.eval()
        for i, data_batch in enumerate(self.val_loader):
            with torch.no_grad():
                images, labels = data_batch
                images, labels = images.to(self.device), labels.to(self.device)
                prob_pred = self.net(images)
            avg_mae += self.eval_mae(prob_pred, labels).cpu().item()
        self.net.train()
        return avg_mae / len(self.val_loader)

    def test(self, num):
        """Evaluate MAE and max F-measure (over `num` thresholds) on the test set."""
        avg_mae, img_num = 0.0, len(self.test_loader)
        avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)
        for i, data_batch in enumerate(self.test_loader):
            with torch.no_grad():
                images, labels = data_batch
                shape = labels.size()[2:]  # original label resolution
                images = images.to(self.device)
                # resize the prediction back to the label's original size
                prob_pred = F.interpolate(self.net(images),
                                          size=shape,
                                          mode='bilinear',
                                          align_corners=True).cpu()
            mae = self.eval_mae(prob_pred, labels)
            prec, recall = self.eval_pr(prob_pred, labels, num)
            print("[%d] mae: %.4f" % (i, mae))
            print("[%d] mae: %.4f" % (i, mae), file=self.test_output)
            avg_mae += mae
            avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
        avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
        # F_beta = (1 + b^2) * P * R / (b^2 * P + R), maximized over thresholds
        score = (1 + self.beta**2) * avg_prec * avg_recall / (
            self.beta**2 * avg_prec + avg_recall)
        score[score != score] = 0  # delete the nan
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()))
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()),
              file=self.test_output)

    def train(self):
        """Optimization loop with periodic logging, validation and checkpoints."""
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        best_mae = 1.0 if self.config.val else None
        for epoch in range(self.config.epoch):
            loss_epoch = 0
            for i, data_batch in enumerate(self.train_loader):
                if (i + 1) > iter_num: break  # skip the trailing partial batch
                self.net.zero_grad()
                x, y = data_batch
                x, y = x.to(self.device), y.to(self.device)
                y_pred = self.net(x)
                loss = self.loss(y_pred, y)
                loss.backward()
                # keep gradient norms bounded before the optimizer step
                utils.clip_grad_norm_(self.net.parameters(),
                                      self.config.clip_gradient)
                self.optimizer.step()
                loss_epoch += loss.cpu().item()
                print(
                    'epoch: [%d/%d], iter: [%d/%d], loss: [%.4f]' %
                    (epoch, self.config.epoch, i, iter_num, loss.cpu().item()))
                if self.config.visdom:
                    error = OrderedDict([('loss:', loss.cpu().item())])
                    self.visual.plot_current_errors(epoch, i / iter_num, error)
            if (epoch + 1) % self.config.epoch_show == 0:
                print('epoch: [%d/%d], epoch_loss: [%.4f]' %
                      (epoch, self.config.epoch, loss_epoch / iter_num),
                      file=self.log_output)
                if self.config.visdom:
                    avg_err = OrderedDict([('avg_loss', loss_epoch / iter_num)
                                           ])
                    self.visual.plot_current_errors(epoch, i / iter_num,
                                                    avg_err, 1)
                    # de-normalized input, ground truth and prediction preview
                    img = OrderedDict([('origin', self.mean + x.cpu()[0]),
                                       ('label', y.cpu()[0][0]),
                                       ('pred_label', y_pred.cpu()[0][0])])
                    self.visual.plot_current_img(img)
            if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
                mae = self.validation()
                print('--- Best MAE: %.4f, Curr MAE: %.4f ---' %
                      (best_mae, mae))
                print('--- Best MAE: %.4f, Curr MAE: %.4f ---' %
                      (best_mae, mae),
                      file=self.log_output)
                if best_mae > mae:
                    best_mae = mae
                    torch.save(self.net.state_dict(),
                               '%s/models/best.pth' % self.config.save_fold)
            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(
                    self.net.state_dict(), '%s/models/epoch_%d.pth' %
                    (self.config.save_fold, epoch + 1))
        torch.save(self.net.state_dict(),
                   '%s/models/final.pth' % self.config.save_fold)
# Example #3 (예제 #3, score: 0)
class Solver(object):
    """Train / validate / test driver for the DSS model (tuple-batch loaders).

    Modernized from pre-0.4 PyTorch idioms: `torch.no_grad()` instead of
    `Variable(..., volatile=True)`, `.item()` instead of `.data[0]`,
    `F.interpolate` instead of the deprecated `F.upsample`, and direct tensor
    feeding instead of pre-allocated `resize_as_/copy_` Variable buffers.
    """

    def __init__(self, train_loader, val_loader, test_loader, config):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.config = config
        self.beta = 0.3  # for max F_beta metric
        # per-channel image mean rescaled to [0, 1]; used to de-normalize for display
        self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1,
                                                                  1) / 255
        # inference: choose the side map (see paper)
        self.select = [1, 2, 3, 6]
        if config.visdom:
            self.visual = Viz_visdom("DSS", 1)
        self.build_model()
        if self.config.pre_trained:
            self.net.load_state_dict(torch.load(self.config.pre_trained))
        if config.mode == 'train':
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            # test mode: load the trained model and freeze it in eval mode
            self.net.load_state_dict(torch.load(self.config.model))
            self.net.eval()
            self.test_output = open("%s/test.txt" % config.test_fold, 'w')

    # print the network information and parameter numbers
    def print_network(self, model, name):
        """Print `name`, the model structure and its total parameter count."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    # build the network
    def build_model(self):
        """Build the network, the loss (train mode only) and the optimizer."""
        self.net = build_model()
        if self.config.mode == 'train': self.loss = Loss()
        if self.config.cuda: self.net = self.net.cuda()
        if self.config.cuda and self.config.mode == 'train':
            self.loss = self.loss.cuda()
        self.net.train()
        self.net.apply(weights_init)
        # either initialize the VGG backbone or resume a full checkpoint
        if self.config.load == '':
            self.net.base.load_state_dict(torch.load(self.config.vgg))
        else:
            self.net.load_state_dict(torch.load(self.config.load))
        self.optimizer = Adam(self.net.parameters(), self.config.lr)
        self.print_network(self.net, 'DSS')

    # update the learning rate
    def update_lr(self, lr):
        """Set `lr` on every optimizer parameter group."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    # evaluate MAE (for test or validation phase)
    def eval_mae(self, y_pred, y):
        """Mean absolute error between prediction and ground truth."""
        return torch.abs(y_pred - y).mean()

    # TODO: write a more efficient version
    # get precisions and recalls: threshold---divided [0, 1] to num values
    def eval_pr(self, y_pred, y, num):
        """Precision/recall at `num` thresholds evenly spaced over [0, 1)."""
        prec, recall = torch.zeros(num), torch.zeros(num)
        thlist = torch.linspace(0, 1 - 1e-10, num)
        for i in range(num):
            y_temp = (y_pred >= thlist[i]).float()
            tp = (y_temp * y).sum()  # true positives at this threshold
            prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / y.sum()
        return prec, recall

    # validation: using resize image, and only evaluate the MAE metric
    def validation(self):
        """Return the mean MAE over the validation loader."""
        avg_mae = 0.0
        self.net.eval()
        # no_grad replaces the removed `Variable(..., volatile=True)` idiom
        with torch.no_grad():
            for i, data_batch in enumerate(self.val_loader):
                images, labels = data_batch
                if self.config.cuda:
                    images, labels = images.cuda(), labels.cuda()
                prob_pred = self.net(images)
                # average the selected side-output maps into one saliency map
                prob_pred = torch.mean(torch.cat(
                    [prob_pred[k] for k in self.select], dim=1),
                                       dim=1,
                                       keepdim=True)
                # .item() replaces `.cpu().data[0]`, which raises IndexError
                # on 0-dim tensors in PyTorch >= 0.5
                avg_mae += self.eval_mae(prob_pred, labels).item()
        self.net.train()
        return avg_mae / len(self.val_loader)

    # test phase: using origin image size, evaluate MAE and max F_beta metrics
    def test(self, num):
        """Evaluate MAE and max F-measure (over `num` thresholds) on the test set."""
        avg_mae, img_num = 0.0, len(self.test_loader)
        avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)
        with torch.no_grad():
            for i, data_batch in enumerate(self.test_loader):
                images, labels = data_batch
                shape = labels.size()[2:]  # original label resolution
                if self.config.cuda:
                    images = images.cuda()
                prob_pred = self.net(images)
                prob_pred = torch.mean(torch.cat(
                    [prob_pred[k] for k in self.select], dim=1),
                                       dim=1,
                                       keepdim=True)
                # F.interpolate replaces the deprecated F.upsample;
                # align_corners=True preserves the legacy bilinear behavior
                prob_pred = F.interpolate(prob_pred, size=shape,
                                          mode='bilinear',
                                          align_corners=True).cpu().data
                mae = self.eval_mae(prob_pred, labels)
                prec, recall = self.eval_pr(prob_pred, labels, num)
                print("[%d] mae: %.4f" % (i, mae))
                print("[%d] mae: %.4f" % (i, mae), file=self.test_output)
                avg_mae += mae.item()  # accumulate a python float, not a tensor
                avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
        avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
        # F_beta = (1 + b^2) * P * R / (b^2 * P + R), maximized over thresholds
        score = (1 + self.beta**2) * avg_prec * avg_recall / (
            self.beta**2 * avg_prec + avg_recall)
        score[score != score] = 0  # zero out NaNs from 0/0 thresholds
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()))
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()),
              file=self.test_output)

    # training phase
    def train(self):
        """Optimization loop with periodic logging, validation and checkpoints."""
        if self.config.cuda:
            cudnn.benchmark = True
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        best_mae = 1.0 if self.config.val else None
        for epoch in range(self.config.epoch):
            loss_epoch = 0
            for i, data_batch in enumerate(self.train_loader):
                if (i + 1) > iter_num: break  # skip the trailing partial batch
                self.net.zero_grad()
                images, labels = data_batch
                if self.config.cuda:
                    images, labels = images.cuda(), labels.cuda()
                # feed tensors directly; the old pre-allocated Variable
                # buffers (resize_as_/copy_) are obsolete since PyTorch 0.4
                y_pred = self.net(images)
                loss = self.loss(y_pred, labels)
                loss.backward()
                # clip_grad_norm_ replaces the removed clip_grad_norm
                utils.clip_grad_norm_(self.net.parameters(),
                                      self.config.clip_gradient)
                self.optimizer.step()
                loss_epoch += loss.item()
                print('epoch: [%d/%d], iter: [%d/%d], loss: [%.4f]' %
                      (epoch, self.config.epoch, i, iter_num, loss.item()))
                if self.config.visdom:
                    error = OrderedDict([('loss:', loss.item())])
                    self.visual.plot_current_errors(epoch, i / iter_num, error)
            if (epoch + 1) % self.config.epoch_show == 0:
                print('epoch: [%d/%d], epoch_loss: [%.4f]' %
                      (epoch, self.config.epoch, loss_epoch / iter_num),
                      file=self.log_output)
                if self.config.visdom:
                    avg_err = OrderedDict([('avg_loss', loss_epoch / iter_num)
                                           ])
                    self.visual.plot_current_errors(epoch, i / iter_num,
                                                    avg_err, 1)
                    # mean of the selected side outputs for the preview image
                    y_show = torch.mean(torch.cat(
                        [y_pred[k] for k in self.select], dim=1),
                                        dim=1,
                                        keepdim=True)
                    img = OrderedDict([('origin', self.mean + images.cpu()[0]),
                                       ('label', labels.cpu()[0][0]),
                                       ('pred_label', y_show.cpu().data[0][0])
                                       ])
                    self.visual.plot_current_img(img)
            if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
                mae = self.validation()
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' %
                      (best_mae, mae))
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' %
                      (best_mae, mae),
                      file=self.log_output)
                if best_mae > mae:
                    best_mae = mae
                    torch.save(self.net.state_dict(),
                               '%s/models/best.pth' % self.config.save_fold)
            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(
                    self.net.state_dict(), '%s/models/epoch_%d.pth' %
                    (self.config.save_fold, epoch + 1))
        torch.save(self.net.state_dict(),
                   '%s/models/final.pth' % self.config.save_fold)
# Example #4 (예제 #4, score: 0)
class Solver(object):
    """Train / validate / test driver for the camouflage ("camu") saliency model."""

    def __init__(self, train_loader, val_loader, test_dataset, config):
        self.train_loader = train_loader
        self.val_loader   = val_loader
        self.test_dataset = test_dataset
        self.config       = config
        self.beta         = math.sqrt(0.3)  # beta^2 = 0.3 for the F_beta metric
        self.device = torch.device('cpu')
        # ImageNet normalization statistics, shaped (3, 1, 1) for broadcasting
        self.mean   = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.std    = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        self.visual_save_fold = config.pre_map  # where predicted maps are written
        if self.config.cuda:
            cudnn.benchmark = True
            self.device     = torch.device('cuda:0')
        if config.visdom:
            self.visual = Viz_visdom("camu", 1)
        self.build_model()
        if self.config.pre_trained: self.net.load_state_dict(torch.load(self.config.pre_trained))
        if config.mode == 'train':
            self.log_output  = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            # test mode: load the trained model and freeze it in eval mode
            self.net.load_state_dict(torch.load(self.config.model))
            self.net.eval()
            self.test_output = open("%s/test.txt" % config.test_fold, 'w')
            self.transform   = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

    # build the network
    def build_model(self):
        """Construct the network, the training loss and the Adam optimizer."""
        self.net = build_model().to(self.device)
        if self.config.mode == 'train': self.loss = Loss().to(self.device)
        # stay in train mode here: the old code called eval() immediately
        # after train(), silently freezing BatchNorm/Dropout during training;
        # __init__ switches to eval() itself for non-train modes
        self.net.train()
        self.optimizer = Adam(self.net.parameters(), self.config.lr)

    # evaluate MAE (for test or validation phase)
    def eval_mae(self, y_pred, y):
        """Mean absolute error between prediction and ground truth."""
        return torch.abs(y_pred - y).mean()

    # TODO: write a more efficient version
    # get precisions and recalls: threshold---divided [0, 1] to num values
    def eval_pr(self, y_pred, y, num):
        """Precision/recall at `num` thresholds evenly spaced over [0, 1)."""
        prec, recall = torch.zeros(num), torch.zeros(num)
        thlist       = torch.linspace(0, 1 - 1e-10, num)
        for i in range(num):
            y_temp   = (y_pred >= thlist[i]).float()
            tp       = (y_temp * y).sum()  # true positives at this threshold
            prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / y.sum()
        return prec, recall

    # validation: using resize image, and only evaluate the MAE metric
    def validation(self):
        """Return the mean MAE of the network's first output over the val loader."""
        avg_mae = 0.0
        self.net.eval()
        with torch.no_grad():
            for i, data_batch in enumerate(self.val_loader):
                images, labels = data_batch
                images, labels = images.to(self.device), labels.to(self.device)
                prob_pred      = self.net(images)
                avg_mae       += self.eval_mae(prob_pred[0], labels).item()
        self.net.train()
        return avg_mae / len(self.val_loader)

    # test phase: using origin image size, evaluate MAE and max F_beta metrics
    def test(self, num, use_crf=False):
        """Evaluate MAE on the test set and save predicted maps as PNGs."""
        if use_crf: from tools.crf_process import crf
        avg_mae, img_num = 0.0, len(self.test_dataset)
        avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)
        with torch.no_grad():
            for i, (img, labels, name) in enumerate(self.test_dataset):
                images = self.transform(img).unsqueeze(0)
                labels = labels.unsqueeze(0)
                shape  = labels.size()[2:]  # original spatial size
                images = images.to(self.device)
                # the net returns multiple outputs; the first one is evaluated
                prob_pred = self.net(images)
                prob_pred = F.interpolate(prob_pred[0], size=shape, mode='bilinear', align_corners=True).cpu().data
                # makedirs(exist_ok=True) avoids the check-then-create race of
                # the old exists()/mkdir pair and handles nested folders too
                os.makedirs(self.visual_save_fold, exist_ok=True)
                img_save = prob_pred.numpy()
                img_save = img_save.reshape(-1, img_save.shape[2], img_save.shape[3]).transpose(1, 2, 0) * 255
                cv2.imwrite('{}/{}.png'.format(self.visual_save_fold, name), img_save.astype(np.uint8))
                mae           = self.eval_mae(prob_pred, labels)
                prec, recall  = self.eval_pr(prob_pred, labels, num)
                print("[%d] mae: %.4f" % (i, mae))
                print("[%d] mae: %.4f" % (i, mae), file=self.test_output)
                avg_mae += mae.item()  # accumulate a python float, not a tensor
                avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
        avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
        print('average mae: %.4f' % (avg_mae))
        print('average mae: %.4f' % (avg_mae), file=self.test_output)

    # training phase
    def train(self):
        """Optimization loop with periodic logging, validation and checkpoints."""
        # integer division: the old float `/` made iter_num a float, so the
        # break condition and per-epoch averages ran on float arithmetic
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        best_mae = 1.0 if self.config.val else None

        for epoch in range(self.config.epoch):
            loss_epoch = 0
            for i, data_batch in enumerate(self.train_loader):
                if (i + 1) > iter_num:
                    break  # skip the trailing partial batch
                self.net.zero_grad()
                x, y     = data_batch
                x, y     = x.to(self.device), y.to(self.device)
                y_pred   = self.net(x)
                loss     = self.loss(y_pred, y)
                loss.backward()
                utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient)
                self.optimizer.step()
                loss_epoch += loss.item()
                print('epoch: [%d/%d], iter: [%d/%d], loss: [%.4f], lr: [%s]' % (
                    epoch, self.config.epoch, i, iter_num, loss.item(), self.config.lr))

                if self.config.visdom:
                    error = OrderedDict([('loss:', loss.item())])
                    self.visual.plot_current_errors(epoch, i / iter_num, error)

            if (epoch + 1) % self.config.epoch_show == 0:
                print('epoch: [%d/%d], epoch_loss: [%.4f], lr: [%s]' % (epoch, self.config.epoch, loss_epoch / iter_num, self.config.lr), file=self.log_output)

            if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
                mae = self.validation()
                print('--- Best MAE: %.5f, Curr MAE: %.5f ---' % (best_mae, mae))
                print('--- Best MAE: %.5f, Curr MAE: %.5f ---' % (best_mae, mae), file=self.log_output)
                if best_mae > mae:
                    best_mae = mae
                    torch.save(self.net.state_dict(), '%s/models/best.pth' % self.config.save_fold)
            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(self.net.state_dict(), '%s/models/epoch_%d.pth' % (self.config.save_fold, epoch + 1))

        torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_fold)
# Example #5 (예제 #5, score: 0)
class Solver(object):
    """Training/evaluation driver for the DSS salient-object-detection network.

    This variant wraps the model in ``torch.nn.DataParallel``, logs scalars to
    TensorBoard, supports resuming from a checkpoint dict (keys ``epoch`` /
    ``state_dict`` / ``optimizer`` / ``best_mae``), and dumps visualization
    images (prediction / background / foreground maps) during testing.
    """

    def __init__(self, train_loader, test_dataset, config):
        self.train_loader = train_loader
        self.test_dataset = test_dataset
        self.config = config
        self.beta = 0.3  # for max F_beta metric
        # inference: choose the side maps to average (see the DSS paper)
        self.select = [1, 2, 3, 6]
        # self.device = torch.device('cpu')
        # ImageNet mean/std, used to de-normalize images for visualization.
        self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

        self.update = config.update
        self.step = config.step
        # TensorBoard logging (modified by hanqi)
        self.summary = TensorboardSummary("%s/logs/" % config.save_fold)
        self.writer = self.summary.create_summary()
        self.visual_save_fold = config.save_fold
        if self.config.cuda:
            cudnn.benchmark = True
            # self.device = torch.device('cuda:0')
        if config.visdom:
            self.visual = Viz_visdom("DSS", 1)
        self.build_model()
        if self.config.pre_trained: self.net.module.load_state_dict(torch.load(self.config.pre_trained))
        # Input transform: resize to the training resolution and normalize.
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # Ground-truth transform: keep original size, binarize via rounding.
        self.t_transform = transforms.Compose([
            # transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: torch.round(x))
            # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        if config.mode == 'train':
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            # Inference mode: load a checkpoint dict and switch to eval().
            self.net.module.load_state_dict(torch.load(self.config.model)["state_dict"])
            self.net.eval()
            # self.test_output = open("%s/test.txt" % config.test_fold, 'w')

    # print the network information and parameter numbers
    def print_network(self, model, name):
        """Print *name*, the model structure, and its trainable-parameter count."""
        num_params = 0
        for p in model.parameters():
            if p.requires_grad: num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    # build the network
    def build_model(self):
        """Build the DataParallel network, training loss, and Adam optimizer."""
        self.net = torch.nn.DataParallel(build_model()).cuda()
        if self.config.mode == 'train': self.loss = Loss().cuda()
        self.net.train()
        self.net.apply(weights_init)
        # Either start from the VGG backbone weights or resume full weights.
        if self.config.load == '': self.net.module.base.load_state_dict(torch.load(self.config.vgg))
        if self.config.load != '': self.net.module.load_state_dict(torch.load(self.config.load))
        self.optimizer = Adam(self.net.parameters(), self.config.lr)
        self.print_network(self.net, 'DSS')

    # update the learning rate
    def update_lr(self):
        """Divide the learning rate of every parameter group by 10."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = param_group['lr'] / 10.0

    # evaluate MAE (for test or validation phase)
    def eval_mae(self, y_pred, y):
        """Mean absolute error between prediction and ground truth."""
        return torch.abs(y_pred - y).mean()

    # TODO: write a more efficient version
    # get precisions and recalls: threshold---divided [0, 1] to num values
    def eval_pr(self, y_pred, y, num):
        """Precision/recall over *num* thresholds evenly spaced in [0, 1).

        NOTE(review): currently unused in this class (the F-measure path in
        ``test`` is commented out); ``tp / y.sum()`` yields NaN for all-zero
        labels -- add an epsilon if it is re-enabled.
        """
        prec, recall = torch.zeros(num), torch.zeros(num)
        thlist = torch.linspace(0, 1 - 1e-10, num)
        for i in range(num):
            y_temp = (y_pred >= thlist[i]).float()
            tp = (y_temp * y).sum()
            prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / y.sum()
        return prec, recall

    # test phase: using origin image size, evaluate MAE and max F_beta metrics
    def test(self, num, use_crf=False, epoch=None):
        """Evaluate on ``self.test_dataset`` at original image size.

        Saves per-image visualizations (prediction, background, foreground
        maps) under ``self.visual_save_fold`` and returns ``(avg_mae, 1.0)``
        (the F-measure computation is currently disabled).
        """
        if use_crf: from tools.crf_process import crf
        avg_mae, img_num = 0.0, 0.0
        avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)  # unused: F-measure path disabled
        with torch.no_grad():
            for i, (img, labels, bg, fg, name) in enumerate(self.test_dataset):
                images = self.transform(img).unsqueeze(0)
                labels = self.t_transform(labels).unsqueeze(0)
                shape = labels.size()[2:]
                images = images.cuda()
                prob_pred = self.net(images, mode='test')
                # Side outputs at offset +7 hold background maps; average the
                # selected ones.  (The comprehension variable has its own scope
                # in Python 3, so the outer ``i`` is unaffected.)
                bg_pred = torch.mean(torch.cat([prob_pred[i+7] for i in self.select], dim=1), dim=1, keepdim=True)
                bg_pred = (bg_pred > 0.5).float()
                prob_pred = torch.mean(torch.cat([prob_pred[i] for i in self.select], dim=1), dim=1, keepdim=True)

                prob_pred = F.interpolate(prob_pred, size=shape, mode='bilinear', align_corners=True).cpu().data
                bg_pred = F.interpolate(bg_pred, size=shape, mode='nearest').cpu().data.numpy()
                fork_bg, fork_fg = Bwdist(bg_pred)
                if use_crf:
                    prob_pred = crf(img, prob_pred.numpy(), to_tensor=True)
                # FIX: the exists()+mkdir() pairs were racy and failed when the
                # parent directory was missing; makedirs(exist_ok=True) is safe.
                os.makedirs('{}/visualize_pred{}/'.format(self.visual_save_fold, epoch), exist_ok=True)
                img_save = prob_pred.numpy()
                img_save = img_save.reshape(-1, img_save.shape[2], img_save.shape[3]).transpose(1, 2, 0) * 255
                cv2.imwrite('{}/visualize_pred{}/{}'.format(self.visual_save_fold, epoch, name), img_save.astype(np.uint8))
                # print('save visualize_pred{}/{} done.'.format(name, epoch))
                os.makedirs('{}/visualize_bg{}/'.format(self.visual_save_fold, epoch), exist_ok=True)
                img_save = fork_bg
                img_save = img_save.reshape(-1, img_save.shape[2], img_save.shape[3]).transpose(1, 2, 0) * 255
                cv2.imwrite('{}/visualize_bg{}/{}'.format(self.visual_save_fold, epoch, name), img_save.astype(np.uint8))
                # print('save visualize_bg{}/{} done.'.format(name, epoch))
                os.makedirs('{}/visualize_fg{}/'.format(self.visual_save_fold, epoch), exist_ok=True)
                img_save = fork_fg
                img_save = img_save.reshape(-1, img_save.shape[2], img_save.shape[3]).transpose(1, 2, 0) * 255
                cv2.imwrite('{}/visualize_fg{}/{}'.format(self.visual_save_fold, epoch, name), img_save.astype(np.uint8))
                # print('save visualize_bg{}/{} done.'.format(name, epoch))
                mae = self.eval_mae(prob_pred, labels)
                if mae == mae:  # NaN check: NaN != NaN
                    avg_mae += mae
                    img_num += 1.0
                    # prec, recall = self.eval_pr(prob_pred, labels, num)
                    # avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
        # FIX: guard against ZeroDivisionError when every per-image MAE was NaN.
        if img_num > 0:
            avg_mae = avg_mae / img_num
        # avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
        # score = (1 + self.beta ** 2) * avg_prec * avg_recall / (self.beta ** 2 * avg_prec + avg_recall)
        # score[score != score] = 0  # delete the nan
        # print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()))
        print('average mae: %.4f' % (avg_mae))
        # print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()), file=self.test_output)
        return avg_mae, 1.0  # score.max()

    # training phase
    def train(self):
        """Main training loop with optional resume, validation, and checkpointing."""
        start_epoch = 0
        best_mae = 1.0 if self.config.val else None
        if self.config.resume is not None:
            if not os.path.isfile(self.config.resume):
                # FIX: this message previously referenced the undefined name
                # ``args``; use the config value that was actually checked.
                raise RuntimeError("=> no checkpoint found at '{}'".format(self.config.resume))
            checkpoint = torch.load(self.config.resume)
            start_epoch = checkpoint['epoch']
            if self.config.cuda:
                self.net.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.net.load_state_dict(checkpoint['state_dict'])

            self.optimizer.load_state_dict(checkpoint['optimizer'])
            best_mae = checkpoint['best_mae']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(self.config.resume, checkpoint['epoch']))

        iter_num = len(self.train_loader.dataset) // self.config.batch_size

        for epoch in range(start_epoch, self.config.epoch):
            # if str(epoch + 1) in self.step:
            #     self.update_lr()
            loss_epoch = 0
            tbar = tqdm(self.train_loader)

            for i, data_batch in enumerate(tbar):
                if (i + 1) > iter_num: break
                self.net.zero_grad()
                x, y, bg, fg = data_batch
                x, y, bg, fg = x.cuda(), y.cuda(), bg.cuda(), fg.cuda()
                y_pred = self.net(x, bg=bg, fg=fg)
                loss = self.loss(y_pred, y)
                loss.backward()
                utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient)
                # utils.clip_grad_norm(self.loss.parameters(), self.config.clip_gradient)
                # if (i+1) % self.update == 0 or (i+1) == iter_num:
                self.optimizer.step()

                loss_epoch += loss.item()
                self.writer.add_scalar('train/total_loss_iter', loss.item(), epoch * iter_num + i)
                tbar.set_description('epoch:[%d/%d],loss:[%.4f]' % (
                    epoch, self.config.epoch, loss.item()))
                # print('epoch: [%d/%d], iter: [%d/%d], loss: [%.4f]' % (
                #     epoch, self.config.epoch, i, iter_num, loss.item()))
                if self.config.visdom:
                    error = OrderedDict([('loss:', loss.item())])
                    self.visual.plot_current_errors(epoch, i / iter_num, error)
            self.writer.add_scalar('train/total_loss_epoch', loss_epoch / iter_num, epoch)
            if (epoch + 1) % self.config.epoch_show == 0:
                print('epoch: [%d/%d], epoch_loss: [%.4f]' % (epoch, self.config.epoch, loss_epoch / iter_num),
                      file=self.log_output)
                if self.config.visdom:
                    avg_err = OrderedDict([('avg_loss', loss_epoch / iter_num)])
                    self.visual.plot_current_errors(epoch, i / iter_num, avg_err, 1)
                    y_show = torch.mean(torch.cat([y_pred[i] for i in self.select], dim=1), dim=1, keepdim=True)
                    img = OrderedDict([('origin', x.cpu()[0] * self.std + self.mean), ('label', y.cpu()[0][0]),
                                       ('pred_label', y_show.cpu().data[0][0])])
                    self.visual.plot_current_img(img)
            if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
                mae, fscore = self.test(100, epoch=epoch + 1)
                self.writer.add_scalar('test/MAE', mae, epoch)
                self.writer.add_scalar('test/F-Score', fscore, epoch)
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' % (best_mae, mae))
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' % (best_mae, mae), file=self.log_output)
                if best_mae > mae:
                    best_mae = mae
                    # Save a full resumable checkpoint, not just the weights.
                    torch.save({
                        'epoch': epoch + 1,
                        'state_dict': self.net.module.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'best_mae': mae
                    }, '%s/models/best.pth' % self.config.save_fold)
                    # torch.save(self.net.state_dict(), '%s/models/best.pth' % self.config.save_fold)
            # if (epoch + 1) % self.config.epoch_save == 0:
            #     torch.save(self.net.module.state_dict(), '%s/models/epoch_%d.pth' % (self.config.save_fold, epoch + 1))
        torch.save(self.net.module.state_dict(), '%s/models/final.pth' % self.config.save_fold)
예제 #6
0
class Solver(object):
    """Training/evaluation driver for DSS with four auxiliary inputs (y1..y4).

    ``mode == 1`` selects ``build_model`` (multi-side-output DSS whose selected
    side maps are averaged at inference); any other mode selects
    ``build_modelv2`` (single-map output).
    """

    def __init__(self, train_loader, val_loader, test_dataset, config, mode):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_dataset = test_dataset
        self.config = config
        # NOTE(review): the F-measure below squares this value (beta ** 2 ==
        # 0.09); the saliency literature normally uses beta ** 2 == 0.3 --
        # confirm which is intended before comparing scores across variants.
        self.beta = 0.3
        self.select = [1, 2, 3, 6]
        # FIX: default to CPU so that config.cuda == False actually runs on
        # CPU; previously the device was unconditionally 'cuda:0'.
        self.device = torch.device('cpu')
        # ImageNet mean/std for (de-)normalization.
        self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        self.mode = mode
        if self.config.mode == "train":
            self.lossfile = open("%s/logs/loss.txt" % config.save_fold, 'w')
            self.maefile = open("%s/logs/mae.txt" % config.save_fold, 'w')
        if self.config.cuda:
            cudnn.benchmark = True
            self.device = torch.device('cuda:0')
        if config.visdom:
            self.visual = Viz_visdom("DSS", 1)
        self.build_model()
        if self.config.pre_trained: self.net.load_state_dict(torch.load(self.config.pre_trained))
        if config.mode == 'train':
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            # Inference mode: load trained weights and open the result file.
            self.net.load_state_dict(torch.load(self.config.model))
            self.net.eval()
            self.test_output = open("%s/test.txt" % config.test_fold, 'w')
            self.transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

    def print_network(self, model, name):
        """Print *name*, the model structure, and its trainable-parameter count."""
        num_params = 0
        for p in model.parameters():
            if p.requires_grad: num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def build_model(self):
        """Build the network selected by ``self.mode``, the loss, and the optimizer."""
        if (self.mode == 1):
            self.net = build_model().to(self.device)
        else:
            self.net = build_modelv2().to(self.device)
        if self.config.mode == 'train': self.loss = Loss().to(self.device)
        self.net.train()
        self.net.apply(weights_init)
        # Either start from the VGG backbone weights or resume full weights.
        if self.config.load == '': self.net.base.load_state_dict(torch.load(self.config.vgg))
        if self.config.load != '': self.net.load_state_dict(torch.load(self.config.load))
        self.optimizer = Adam(self.net.parameters(), self.config.lr)
        self.print_network(self.net, 'DSS')

    # update the learning rate
    def update_lr(self, lr):
        """Set the learning rate of every parameter group to *lr*."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    # evaluate MAE (for test or validation phase)
    def eval_mae(self, y_pred, y):
        """Mean absolute error between prediction and ground truth."""
        return torch.abs(y_pred - y).mean()

    def eval_pr(self, y_pred, y, num):
        """Precision/recall over *num* thresholds evenly spaced in [0, 1)."""
        prec, recall = torch.zeros(num), torch.zeros(num)
        thlist = torch.linspace(0, 1 - 1e-10, num)
        for i in range(num):
            y_temp = (y_pred >= thlist[i]).float()
            tp = (y_temp * y).sum()
            # FIX: add the same epsilon to the recall denominator as precision
            # already has, so an all-zero label no longer produces NaN.
            prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20)
        return prec, recall

    # validation: using resize image, and only evaluate the MAE metric
    def validation(self):
        """Run the network over the validation loader and return the mean MAE."""
        avg_mae = 0.0
        self.net.eval()
        with torch.no_grad():
            for i, data_batch in enumerate(self.val_loader):
                images, y1, y2, y3, y4, labels = data_batch
                images, labels = images.to(self.device), labels.to(self.device)
                y1 = y1.to(self.device)
                y2 = y2.to(self.device)
                y3 = y3.to(self.device)
                y4 = y4.to(self.device)
                prob_pred = self.net(images, y1, y2, y3, y4)
                avg_mae += self.eval_mae(prob_pred, labels).item()
        self.net.train()
        return avg_mae / len(self.val_loader)

    # test phase: using origin image size, evaluate MAE and max F_beta metrics
    def test(self, num, use_crf=False):
        """Evaluate MAE and max F_beta on ``self.test_dataset``.

        Images whose transformed size is not 1x3x256x256 are skipped.
        Results are printed and written to ``self.test_output``.
        """
        if use_crf: from tools.crf_process import crf
        # FIX: count only the images that are actually evaluated; skipped
        # samples (wrong size) previously still inflated the divisor because
        # img_num was initialized to len(self.test_dataset).
        avg_mae, img_num = 0.0, 0
        avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)
        with torch.no_grad():
            for i, (img, y1, y2, y3, y4, labels) in enumerate(self.test_dataset):
                # NOTE(review): opens an image-viewer window for every test
                # sample -- looks like a debugging leftover; confirm before use.
                img.show()
                images = self.transform(img).unsqueeze(0)
                y1 = self.transform(y1).unsqueeze(0)
                y2 = self.transform(y2).unsqueeze(0)
                y3 = self.transform(y3).unsqueeze(0)
                y4 = self.transform(y4).unsqueeze(0)
                if (images.shape != torch.Size([1, 3, 256, 256])):
                    continue
                labels = labels.unsqueeze(0)
                shape = labels.size()[2:]
                images = images.to(self.device)
                y1 = y1.to(self.device)
                y2 = y2.to(self.device)
                y3 = y3.to(self.device)
                y4 = y4.to(self.device)
                prob_pred = self.net(images, y1, y2, y3, y4)
                if (self.mode == 1):
                    # Multi-side-output model: average the selected side maps.
                    prob_pred = torch.mean(torch.cat([prob_pred[i] for i in self.select], dim=1), dim=1, keepdim=True)
                    prob_pred = F.interpolate(prob_pred, size=shape, mode='bilinear', align_corners=True).cpu().data
                else:
                    prob_pred = F.interpolate(prob_pred, size=shape, mode='bilinear', align_corners=True).cpu().data

                if use_crf:
                    prob_pred = crf(img, prob_pred.numpy(), to_tensor=True)
                mae = self.eval_mae(prob_pred, labels)
                prec, recall = self.eval_pr(prob_pred, labels, num)
                print("[%d] mae: %.4f" % (i, mae))
                print("[%d] mae: %.4f" % (i, mae), file=self.test_output)

                #********************To present hard cases**********************************************
                """
                if (mae>0.2):
                    img.show()
                    ss = prob_pred[0][0].cpu().numpy()
                    ss = 256 * ss
                    ims1 = Image.fromarray(ss)
                    ims1.show()
                """
                avg_mae += mae
                img_num += 1
                avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
        if img_num == 0:
            # Nothing matched the expected 256x256 size; avoid division by zero.
            print('no valid 256x256 test images were evaluated')
            return
        avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
        score = (1 + self.beta ** 2) * avg_prec * avg_recall / (self.beta ** 2 * avg_prec + avg_recall)
        score[score != score] = 0  # delete the NaN entries
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()))
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()), file=self.test_output)

    # training phase
    def train(self):
        """Main training loop; logs per-epoch loss/MAE and saves checkpoints."""
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        best_mae = 1.0 if self.config.val else None
        self.lossfile.write("epoch\tavg_loss\n")
        self.maefile.write("epoch\tavg_mae\n")
        for epoch in range(self.config.epoch):
            loss_epoch = 0
            # learning rate decay.
            # NOTE(review): this sets the LR back to config.lr at epoch 30
            # rather than decaying it -- confirm the intended schedule.
            if epoch == 30:
                lr = self.config.lr
                self.update_lr(lr)
            mae = 0
            for i, data_batch in enumerate(self.train_loader):
                if (i + 1) > iter_num: break
                self.net.zero_grad()
                x, y1, y2, y3, y4, y = data_batch
                x, y1, y2, y3, y4, y = x.to(self.device), y1.to(self.device), y2.to(self.device), \
                                     y3.to(self.device), y4.to(self.device), y.to(self.device)

                y_pred = self.net(x, y1, y2, y3, y4)
                loss = self.loss(y_pred, y)
                loss.backward()
                utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient)
                # utils.clip_grad_norm(self.loss.parameters(), self.config.clip_gradient)
                self.optimizer.step()
                loss_epoch += loss.item()
                tmp_mae = self.eval_mae(y_pred, y).item()
                mae += tmp_mae
                print('epoch: [%d/%d], iter: [%d/%d], loss: [%.4f], mae: [%.4f]' % (
                    epoch, self.config.epoch, i, iter_num, loss.item(), tmp_mae))
                if self.config.visdom:
                    error = OrderedDict([('loss:', loss.item())])
                    self.visual.plot_current_errors(epoch, i / iter_num, error)
            avg_loss = loss_epoch / iter_num
            self.lossfile.write("%d\t%.4f\n" % (epoch, avg_loss))
            avg_mae = mae / iter_num
            self.maefile.write("%d\t%.4f\n" % (epoch, avg_mae))
            if (epoch + 1) % self.config.epoch_show == 0:
                print('epoch: [%d/%d], epoch_loss: [%.4f]' % (epoch, self.config.epoch, loss_epoch / iter_num),
                      file=self.log_output)
                if self.config.visdom:
                    avg_err = OrderedDict([('avg_loss', loss_epoch / iter_num)])
                    self.visual.plot_current_errors(epoch, i / iter_num, avg_err, 1)
                    y_show = torch.mean(torch.cat([y_pred[i] for i in self.select], dim=1), dim=1, keepdim=True)
                    img = OrderedDict([('origin', x.cpu()[0] * self.std + self.mean), ('label', y.cpu()[0][0]),
                                       ('pred_label', y_show.cpu().data[0][0])])
                    self.visual.plot_current_img(img)

            if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
                mae = self.validation()
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' % (best_mae, mae))
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' % (best_mae, mae), file=self.log_output)
                if best_mae > mae:
                    best_mae = mae
                    torch.save(self.net.state_dict(), '%s/models/mybest.pth' % self.config.save_fold)
            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(self.net.state_dict(), '%s/models/epoch_%d.pth' % (self.config.save_fold, epoch + 1))

        torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_fold)
예제 #7
0
class Solver(object):
    def __init__(self, train_loader, target_loader, val_loader, test_dataset,
                 config):
        """Store data loaders and config, build the networks, open output files.

        Besides the usual source-domain loaders, a ``target_loader`` is kept
        (presumably unlabeled target-domain images for adversarial
        adaptation -- confirm against the caller and ``train``).
        """
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_dataset = test_dataset
        self.target_loader = target_loader
        self.config = config
        self.beta = math.sqrt(0.3)  # for max F_beta metric (so beta ** 2 == 0.3)
        # inference: choose the side map (see paper)
        self.select = [1, 2, 3, 6]
        self.device = torch.device('cpu')
        # ImageNet mean/std for (de-)normalization.
        self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        self.TENSORBOARD_LOGDIR = f'{config.save_fold}/tensorboards'
        self.TENSORBOARD_VIZRATE = 100
        if self.config.cuda:
            cudnn.benchmark = True
            self.device = torch.device('cuda:0')
        if config.visdom:
            self.visual = Viz_visdom("DSS", 1)
        self.build_model()
        if self.config.pre_trained:
            self.net.load_state_dict(torch.load(self.config.pre_trained))
        if config.mode == 'train':
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
            self.val_output = open("%s/logs/val.txt" % config.save_fold, 'w')
        else:
            # Inference mode: load trained weights, switch to eval() and open
            # the per-image result files / output-map folder.
            self.net.load_state_dict(torch.load(self.config.model))
            self.net.eval()
            self.test_output = open("%s/test.txt" % config.test_fold, 'w')
            self.test_maeid = open("%s/mae_id.txt" % config.test_fold, 'w')
            self.test_outmap = config.test_map_fold
            self.transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

    # print the network information and parameter numbers
    def print_network(self, model, name):
        """Print *name*, the model structure, and its trainable-parameter count."""
        total = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(name)
        print(model)
        print("The number of parameters: {}".format(total))

    # # build the network
    # def build_model(self):
    #     self.net = build_model().to(self.device)
    #     if self.config.mode == 'train': self.loss = Loss().to(self.device)
    #     self.net.train()
    #     self.net.apply(weights_init)
    #     if self.config.load == '': self.net.base.load_state_dict(torch.load(self.config.vgg))
    #     if self.config.load != '': self.net.load_state_dict(torch.load(self.config.load))
    #     self.optimizer = Adam(self.net.parameters(), self.config.lr)
    #     self.print_network(self.net, 'DSS')

    # # build the network --new
    def build_model(self):
        """Build two DSS networks (``net`` / ``net2``) with separate Adam optimizers.

        In train mode also creates the project ``Loss``, an MSE loss and an
        IoU loss.  Each network either starts from the VGG backbone weights
        (``config.load == ''``) or resumes full weights from ``config.load``.
        """
        if self.config.mode == 'train':
            self.loss = Loss().to(self.device)
            self.l2loss = nn.MSELoss().to(self.device)
            self.iouloss = IoULoss().to(self.device)
        self.net = build_model().to(self.device)
        self.net.train()
        self.net.apply(weights_init)
        if self.config.load == '':
            self.net.base.load_state_dict(torch.load(self.config.vgg))
        if self.config.load != '':
            self.net.load_state_dict(torch.load(self.config.load))
        self.optimizer = Adam(self.net.parameters(), self.config.lr)

        # Second, independently-initialized network with its own optimizer
        # (presumably for a co-training / adversarial setup -- confirm in train()).
        self.net2 = build_model().to(self.device)
        self.net2.train()
        self.net2.apply(weights_init)
        if self.config.load == '':
            self.net2.base.load_state_dict(torch.load(self.config.vgg))
        if self.config.load != '':
            self.net2.load_state_dict(torch.load(self.config.load))
        self.optimizer2 = Adam(self.net2.parameters(), self.config.lr)
        # self.print_network(self.net, 'DSS')

    # update the learning rate
    def update_lr(self, lr):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    # evaluate MAE (for test or validation phase)
    def eval_mae(self, y_pred, y):
        """Mean absolute error between prediction and ground truth."""
        return (y_pred - y).abs().mean()

    # TODO: write a more efficient version
    # get precisions and recalls: threshold---divided [0, 1] to num values
    def eval_pr(self, y_pred, y, num):
        """Precision/recall curves over *num* thresholds evenly spaced in [0, 1).

        Both denominators carry a 1e-20 epsilon so all-negative predictions or
        all-zero labels yield 0 instead of NaN.
        """
        prec = torch.zeros(num)
        recall = torch.zeros(num)
        thresholds = torch.linspace(0, 1 - 1e-10, num)
        for idx, th in enumerate(thresholds):
            binary = (y_pred >= th).float()
            tp = (binary * y).sum()
            prec[idx] = tp / (binary.sum() + 1e-20)
            recall[idx] = tp / (y.sum() + 1e-20)
        return prec, recall

    # validation: using resize image, and only evaluate the MAE metric
    def validation(self):
        avg_mae, avg_loss = 0.0, 0.0
        self.net.eval()
        with torch.no_grad():
            for i, data_batch in enumerate(self.val_loader):
                images, labels = data_batch
                shape = labels.size()[2:]
                images, labels = images.to(self.device), labels.to(self.device)
                _, prob_pred = self.net(images)
                # for side_num in range(len(prob_pred)):
                #         tmp = torch.sigmoid(prob_pred[side_num])[0]
                #         tmp = tmp.cpu().data
                #         img = ToPILImage()(tmp)
                #         img.save(self.config.val_fold_sub + '/' + self.val_loader.dataset.label_path[i][36:-4] +'_side_' + str(side_num) + '.png')
                # prob_pred1 = torch.mean(torch.cat([prob_pred[i] for i in self.select], dim=1), dim=1, keepdim=True)
                # prob_pred1 = F.interpolate(prob_pred, size=shape, mode='bilinear', align_corners=True)
                # prob_pred2 = torch.mean(torch.cat([torch.sigmoid(prob_pred[i]) for i in self.select], dim=1), dim=1, keepdim=True)
                # prob_pred2 = F.interpolate(prob_pred2, size=shape, mode='bilinear', align_corners=True)
                prob_pred2 = F.interpolate(torch.sigmoid(prob_pred),
                                           size=shape,
                                           mode='bilinear',
                                           align_corners=True)
                # avg_loss += self.loss(prob_pred2, labels).item()
                avg_mae += self.eval_mae(prob_pred2, labels).item()
        self.net.train()
        return avg_mae / len(self.val_loader), avg_loss / len(self.val_loader)

    # test phase: using origin image size, evaluate MAE and max F_beta metrics
    def test(self, num, use_crf=False):
        """Evaluate on ``self.test_dataset`` at original image size.

        Computes MAE and the max F_beta score over *num* thresholds, saves
        every prediction map as an image under ``self.test_outmap`` and logs
        per-image / average metrics to ``self.test_output``.
        """
        if use_crf: from tools.crf_process import crf
        dic = {}
        avg_mae, img_num = 0.0, len(self.test_dataset)
        avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)
        with torch.no_grad():
            for i, (img, labels) in enumerate(self.test_dataset):
                images = self.transform(img).unsqueeze(0)
                labels = labels.unsqueeze(0)
                # Evaluate at the label's (original) resolution.
                shape = labels.size()[2:]
                images = images.to(self.device)
                # The network returns a pair; only the second output (the
                # saliency map) is used here.
                _, prob_pred = self.net(images)
                # prob_pred = torch.mean(torch.cat([torch.sigmoid(prob_pred[i]) for i in self.select], dim=1), dim=1, keepdim=True)
                prob_pred = F.interpolate(torch.sigmoid(prob_pred),
                                          size=shape,
                                          mode='bilinear',
                                          align_corners=True).cpu().data
                # prob_pred = F.interpolate(prob_pred, size=shape, mode='bilinear', align_corners=True).cpu().data
                if use_crf:
                    # Optional CRF refinement of the predicted map.
                    prob_pred = crf(img, prob_pred.numpy(), to_tensor=True)
                mae = self.eval_mae(prob_pred, labels)
                # dic.update({self.test_dataset.label_path[i][self.config.test_map_save_pos:-4] : mae})
                prec, recall = self.eval_pr(prob_pred, labels, num)
                # Save the prediction map next to its label's relative path.
                tmp = prob_pred[0]
                imgpred = ToPILImage()(tmp)
                imgpred.save(self.test_outmap + '/' +
                             self.test_dataset.label_path[i]
                             [self.config.test_map_save_pos:])
                print("[%d] mae: %.4f" % (i, mae))
                print("[%d] mae: %.4f" % (i, mae), file=self.test_output)
                avg_mae += mae
                avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
        avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
        # F_beta over the averaged precision/recall curves.
        score = (1 + self.beta**2) * avg_prec * avg_recall / (
            self.beta**2 * avg_prec + avg_recall)
        score[score != score] = 0  # delete the nan
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()))
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()),
              file=self.test_output)

        # dic_sorted = sorted(dic.items(), key = lambda kv:(kv[1], kv[0]),reverse=True)
        # # file1 = open('/data1/liumengmeng/CG4_id_mae/HKU-IS.txt','w')
        # for i in range(int(len(dic_sorted)*0.1)):
        #     print(dic_sorted[i][0] ,file=self.test_maeid)

    def test_bg(self):
        """Estimate the predicted-foreground ratio of every test image.

        For each image, the selected side maps are averaged, thresholded at
        0.5, and the fraction of positive pixels is recorded; the images are
        then logged sorted by that ratio (ascending) to ``self.test_output``.
        """
        dic = {}
        with torch.no_grad():
            for i, img in enumerate(self.test_dataset):
                print(self.test_dataset.image_path[i]
                      [self.config.test_map_save_pos:-4])
                try:
                    images = self.transform(img).unsqueeze(0)
                    images = images.to(self.device)
                    prob_pred = self.net(images)
                    # Average the selected side maps.  (The comprehension
                    # variable shadows the outer ``i`` only inside its own
                    # scope in Python 3, so the enumerate index is unaffected.)
                    prob_pred = torch.mean(torch.cat(
                        [torch.sigmoid(prob_pred[i]) for i in self.select],
                        dim=1),
                                           dim=1,
                                           keepdim=True)
                    prob_pred = prob_pred.cpu().data
                    tmp = prob_pred[0]
                    probarray = tmp.numpy()
                    # Fraction of pixels predicted as foreground (> 0.5).
                    num_1 = len(np.argwhere(probarray > 0.5))
                    ratio = num_1 / (tmp.shape[1] * tmp.shape[2])
                    dic.update({
                        self.test_dataset.image_path[i][self.config.test_map_save_pos:-4]:
                        ratio
                    })
                    print(ratio)
                except TypeError:
                    # FIX: this previously wrote to the undefined name
                    # ``filebad_id`` (NameError); log unreadable ids to the
                    # test output file instead.
                    print(self.test_dataset.image_path[i]
                          [self.config.test_map_save_pos:-4],
                          file=self.test_output)

        dic_sorted = sorted(dic.items(), key=lambda kv: (kv[1], kv[0]))
        # FIX: ``self.test_bg_id`` is never initialized in __init__, which made
        # the loop below raise AttributeError; fall back to the test output
        # file when the dedicated id file was not provided.
        bg_id_out = getattr(self, 'test_bg_id', self.test_output)
        for i in dic_sorted:
            print(f'{i[0]} : {i[1]}', file=self.test_output)
            print(i[0], file=bg_id_out)

    def train(self):
        """Domain-adaptation training loop.

        Jointly trains `self.net` on the source domain (segmentation loss on
        all side outputs) and `self.net2` on the target domain (IoU loss on
        the last side output), coupling the two networks with an L2 loss
        between their attention maps. An optional ADVENT-style adversarial
        branch exists behind `config.add_adv` but is currently broken (see
        NOTE inside).

        NOTE(review): depends on attributes not assigned in the visible
        __init__ (`self.TENSORBOARD_LOGDIR`, `self.target_loader`,
        `self.net2`, `self.optimizer`, `self.optimizer2`, `self.l2loss`,
        `self.iouloss`, `self.val_output`) — presumably set elsewhere in the
        file; verify before use.
        """
        num_classes = 1
        # Tensorboard logging is enabled only if the log dir already exists.
        viz_tensorboard = os.path.exists(self.TENSORBOARD_LOGDIR)
        if viz_tensorboard:
            writer = SummaryWriter(log_dir=self.TENSORBOARD_LOGDIR)

        # # DISCRIMINATOR NETWORK
        # d_main = get_fc_discriminator(num_classes=num_classes)
        # d_main.train()
        # d_main.to(self.device)
        # # # OPTIMIZERS
        # # # discriminators' optimizers
        # optimizer_d_main = optim.Adam(d_main.parameters(), lr=self.config.lr_d,
        #                             betas=(0.9, 0.99))
        # # LABELS for adversarial training-------------------------------------------------------
        # source_label = 0
        # target_label = 1
        # NOTE(review): these enumerate iterators are created once and never
        # re-created, so if `config.early_stop` exceeds the number of batches
        # in either loader the `__next__()` calls below raise StopIteration.
        trainloader_iter = enumerate(self.train_loader)
        targetloader_iter = enumerate(self.target_loader)
        best_mae = 1.0 if self.config.val else None

        for i_iter in tqdm(range(self.config.early_stop)):

            # if i_iter >= 3000:
            #     self.update_lr(1e-5)

            # # reset optimizers
            self.optimizer.zero_grad()
            self.optimizer2.zero_grad()
            # optimizer_d_main.zero_grad()

            # # adapt LR if needed
            # adjust_learning_rate(self.optimizer, i_iter, cfg)
            # adjust_learning_rate_discriminator(optimizer_d_aux, i_iter, cfg)
            # adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)

            # # UDA Training--------------------------------------------------------------------------
            # # only train segnet. Don't accumulate grads in disciminators
            # for param in d_main.parameters():
            #     param.requires_grad = False

            # # train on source with seg loss
            # _, batch = trainloader_iter.__next__()
            # imgs_src, labels_src = batch
            # imgs_src, labels_src = imgs_src.to(self.device), labels_src.to(self.device)
            # pred_src_main = self.net(imgs_src)
            # # loss_seg_src = self.loss(pred_src_main[0], labels_src) #side output 1
            # loss_seg_src = self.loss(pred_src_main, labels_src) #side output 1 - 6 with fusion
            # loss = loss_seg_src
            # loss.backward()
            # utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient)

            # # train on target with seg loss
            # _, batch1 = targetloader_iter.__next__()
            # imgs_trg, labels_trg = batch1
            # imgs_trg, labels_trg = imgs_trg.to(self.device), labels_trg.to(self.device)
            # pred_trg = self.net(imgs_trg)
            # loss_seg_trg = self.loss(pred_trg[5], labels_trg) # side output 6
            # loss = loss_seg_trg
            # loss.backward()
            # utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient)

            #----train on source branch------------------------------
            _, batch = trainloader_iter.__next__()
            imgs_src, labels_src = batch
            imgs_src, labels_src = imgs_src.to(self.device), labels_src.to(
                self.device)
            # each net returns (attention map, side-output predictions);
            # net2 is run on the same source batch only for its attention map
            smap, pred_src = self.net(imgs_src)
            stmap, _ = self.net2(imgs_src)
            loss_seg_src = self.loss(pred_src, labels_src)  #sigmoid BCE loss
            loss_fc_src = self.l2loss(smap,
                                      stmap)  #L2 loss -> self attention maps
            loss = loss_seg_src + loss_fc_src
            loss.backward()
            # utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient)

            #----train on target branch-----------------------------
            _, batch1 = targetloader_iter.__next__()
            imgs_trg, labels_trg = batch1
            imgs_trg, labels_trg = imgs_trg.to(self.device), labels_trg.to(
                self.device)
            # mirror of the source branch: net2 is supervised on target,
            # net contributes only its attention map for the coupling loss
            tmap, pred_trg = self.net2(imgs_trg)
            tsmap, _ = self.net(imgs_trg)
            loss_ctr_trg = self.iouloss(pred_trg[-1],
                                        labels_trg)  # IoU loss: dns6
            loss_fc_trg = self.l2loss(tmap,
                                      tsmap)  #L2 loss -> self attention maps
            loss = loss_ctr_trg + loss_fc_trg
            loss.backward()

            # gradients from both branches have accumulated on both nets;
            # clip them before stepping
            utils.clip_grad_norm_(self.net.parameters(),
                                  self.config.clip_gradient)
            utils.clip_grad_norm_(self.net2.parameters(),
                                  self.config.clip_gradient)

            current_losses = {
                'loss_seg_src': loss_seg_src,
                'loss_fc_src': loss_fc_src,
                'loss_ctr_trg': loss_ctr_trg,
                'loss_fc_trg': loss_fc_trg
            }

            #-------add ADVENT--------------------------------------------------------------------------------
            # NOTE(review): this branch is broken as-is — `d_main`,
            # `bce_loss`, `source_label`, `target_label` and `pred_src_main`
            # are all undefined here (their setup above is commented out), so
            # enabling `config.add_adv` raises NameError at runtime.
            if self.config.add_adv:
                # adversarial training ot fool the discriminator
                _, batch = targetloader_iter.__next__()
                images = batch
                images = images.to(self.device)
                pred_trg_main = self.net(images)

                # d_out_main = d_main(torch.sigmoid(pred_trg_main))
                # loss_adv_trg_main = bce_loss(d_out_main, source_label)
                # loss_adv_trg = self.config.LAMBDA_ADV_MAIN * loss_adv_trg_main
                # loss = loss_adv_trg
                # loss.backward()

                # generator step: push the discriminator to label every target
                # side output as "source"
                # d_out_main = d_main(prob_2_entropy(pred_trg_main[0]))
                d_out_main = d_main(torch.sigmoid(pred_trg_main[0]))
                loss_adv_trg_main = bce_loss(d_out_main, source_label)
                loss_adv_trg = self.config.LAMBDA_ADV_MAIN * loss_adv_trg_main
                for i in range(len(pred_trg_main) - 1):
                    # d_out_main = d_main(prob_2_entropy(pred_trg_main[i+1]))
                    d_out_main = d_main(torch.sigmoid(pred_trg_main[i + 1]))
                    loss_adv_trg_main = bce_loss(d_out_main, source_label)
                    loss_adv_trg += self.config.LAMBDA_ADV_MAIN * loss_adv_trg_main
                loss = loss_adv_trg
                loss.backward()

                # Train discriminator networks--------------------
                # enable training mode on discriminator networks
                for param in d_main.parameters():
                    param.requires_grad = True

                # # train with source
                # pred_src_main = pred_src_main.detach()
                # d_out_main = d_main(torch.sigmoid(pred_src_main))
                # loss_d_main = bce_loss(d_out_main, source_label)
                # loss_d_src = loss_d_main / 2
                # loss_d = loss_d_src
                # loss_d.backward()

                # # train with target
                # pred_trg_main = pred_trg_main.detach()
                # d_out_main = d_main(torch.sigmoid(pred_trg_main))
                # loss_d_main = bce_loss(d_out_main, target_label)
                # loss_d_trg = loss_d_main / 2
                # loss_d = loss_d_trg
                # loss_d.backward()

                # train with source
                # (detach so discriminator gradients do not reach the segnet)
                pred_src_main[0] = pred_src_main[0].detach()
                # d_out_main = d_main(prob_2_entropy(pred_src_main[0]))
                d_out_main = d_main(torch.sigmoid(pred_src_main[0]))
                loss_d_main = bce_loss(d_out_main, source_label)
                loss_d_src = loss_d_main / 2
                for i in range(len(pred_src_main) - 1):
                    pred_src_main[i + 1] = pred_src_main[i + 1].detach()
                    # d_out_main = d_main(prob_2_entropy(pred_src_main[i+1]))
                    d_out_main = d_main(torch.sigmoid(pred_src_main[i + 1]))
                    loss_d_main = bce_loss(d_out_main, source_label)
                    loss_d_src += loss_d_main / 2
                loss_d = loss_d_src
                loss_d.backward()

                # train with target
                pred_trg_main[0] = pred_trg_main[0].detach()
                # d_out_main = d_main(prob_2_entropy(pred_trg_main[0]))
                d_out_main = d_main(torch.sigmoid(pred_trg_main[0]))
                loss_d_main = bce_loss(d_out_main, target_label)
                loss_d_trg = loss_d_main / 2
                for i in range(len(pred_trg_main) - 1):
                    pred_trg_main[i + 1] = pred_trg_main[i + 1].detach()
                    # d_out_main = d_main(prob_2_entropy(pred_trg_main[i+1]))
                    d_out_main = d_main(torch.sigmoid(pred_trg_main[i + 1]))
                    loss_d_main = bce_loss(d_out_main, target_label)
                    loss_d_trg += loss_d_main / 2
                loss_d = loss_d_trg
                loss_d.backward()

                current_losses = {
                    'loss_seg_src': loss_seg_src,
                    'loss_adv_trg': loss_adv_trg,
                    'loss_d_src': loss_d_src,
                    'loss_d_trg': loss_d_trg
                }

            # # optimizer.step()------------------------------------------------------------------------------
            self.optimizer.step()
            self.optimizer2.step()
            # optimizer_d_main.step()

            # current_losses = {
            #                 'loss_seg_src': loss_seg_src}
            #                 # 'loss_adv_trg': loss_adv_trg,
            #                 # 'loss_d_src': loss_d_src,
            #                 # 'loss_d_trg': loss_d_trg}
            print_losses(current_losses, i_iter, self.log_output)

            if self.config.val and (i_iter + 1) % self.config.iter_val == 0:
                # val = i_iter + 1
                # os.mkdir("%s/val-%d" % (self.config.val_fold, val))
                # self.config.val_fold_sub = "%s/val-%d" % (self.config.val_fold, val)
                mae, loss_val = self.validation()
                # NOTE(review): `writer` only exists when viz_tensorboard is
                # True — this call raises NameError otherwise; confirm the
                # log dir is always created before validation runs.
                log_vals_tensorboard(writer, best_mae, mae, loss_val,
                                     i_iter + 1)
                tqdm.write('%d:--- Best MAE: %.4f, Curr MAE: %.4f ---' %
                           ((i_iter + 1), best_mae, mae))
                print('  %d:--- Best MAE: %.4f, Curr MAE: %.4f ---' %
                      ((i_iter + 1), best_mae, mae),
                      file=self.log_output)
                print('  %d:--- Best MAE: %.4f, Curr MAE: %.4f ---' %
                      ((i_iter + 1), best_mae, mae),
                      file=self.val_output)
                if best_mae > mae:
                    best_mae = mae
                    torch.save(self.net.state_dict(),
                               '%s/models/best.pth' % self.config.save_fold)

            # snapshot saving is commented out; this block now only enforces
            # the early-stop break
            if (i_iter + 1) % self.config.iter_save == 0 and i_iter != 0:
                # tqdm.write('taking snapshot ...')
                # torch.save(self.net.state_dict(), '%s/models/iter_%d.pth' % (self.config.save_fold, i_iter + 1))
                # torch.save(d_main.state_dict(), '%s/models/iter_Discriminator_%d.pth' % (self.config.save_fold, i_iter + 1))
                if i_iter >= self.config.early_stop - 1:
                    break

            sys.stdout.flush()

            if viz_tensorboard:
                log_losses_tensorboard(writer, current_losses, i_iter)
                # if i_iter % self.TENSORBOARD_VIZRATE == self.TENSORBOARD_VIZRATE - 1:
                #     draw_in_tensorboard(writer, images, i_iter, pred_trg_main, num_classes, 'T')
                #     draw_in_tensorboard(writer, images_source, i_iter, pred_src_main, num_classes, 'S')

        # torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_fold)

    def train_old(self):
        """Plain single-network training loop (pre domain-adaptation version).

        Iterates `config.epoch` epochs over `self.train_loader`, optimizing
        `self.net` with `self.loss` under gradient clipping. Optionally plots
        per-iteration and per-epoch losses (and an image triple) to visdom,
        logs epoch losses to `self.log_output`, checkpoints every
        `config.epoch_save` epochs, and saves a final model at the end.
        """
        dataset_size = len(self.train_loader.dataset)
        print(dataset_size)
        iters_per_epoch = dataset_size // self.config.batch_size
        # kept for parity with the validation flow; validation itself is
        # currently disabled in this method
        best_mae = 1.0 if self.config.val else None
        for epoch in range(self.config.epoch):
            running_loss = 0
            for step, (inputs, targets) in enumerate(self.train_loader):
                # drop the trailing partial batch of the epoch
                if (step + 1) > iters_per_epoch:
                    break
                self.net.zero_grad()
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                preds = self.net(inputs)
                batch_loss = self.loss(preds, targets)
                batch_loss.backward()
                utils.clip_grad_norm_(self.net.parameters(),
                                      self.config.clip_gradient)
                self.optimizer.step()
                running_loss += batch_loss.item()
                print('epoch: [%d/%d], iter: [%d/%d], loss: [%.4f]' %
                      (epoch, self.config.epoch, step, iters_per_epoch,
                       batch_loss.item()))
                if self.config.visdom:
                    self.visual.plot_current_errors(
                        epoch, step / iters_per_epoch,
                        OrderedDict([('loss:', batch_loss.item())]))

            if (epoch + 1) % self.config.epoch_show == 0:
                print('epoch: [%d/%d], epoch_loss: [%.4f]' %
                      (epoch, self.config.epoch,
                       running_loss / iters_per_epoch),
                      file=self.log_output)
                if self.config.visdom:
                    self.visual.plot_current_errors(
                        epoch, step / iters_per_epoch,
                        OrderedDict([('avg_loss',
                                      running_loss / iters_per_epoch)]), 1)
                    # fuse the selected side outputs of the last batch into a
                    # single-channel map for display
                    fused = torch.mean(torch.cat(
                        [preds[k] for k in self.select], dim=1),
                                       dim=1, keepdim=True)
                    self.visual.plot_current_img(OrderedDict([
                        ('origin', inputs.cpu()[0] * self.std + self.mean),
                        ('label', targets.cpu()[0][0]),
                        ('pred_label', fused.cpu().data[0][0])
                    ]))

            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(
                    self.net.state_dict(), '%s/models/epoch_%d.pth' %
                    (self.config.save_fold, epoch + 1))
        torch.save(self.net.state_dict(),
                   '%s/models/final.pth' % self.config.save_fold)