from tensorboardX import SummaryWriter
from dataset import HelenDataset
from torchvision import transforms
from preprocess import ToPILImage, ToTensor, OrigPad, Resize
from torch.utils.data import DataLoader
from helper_funcs import F1Score, calc_centroid, affine_crop, affine_mapback
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torch
import numpy as np
import os

# NOTE: Stage1Model, Stage2Model, FaceModel, SelectNet, SelectNet_resnet, the
# TemplateModel base class, and the module-level `args`, `dataloader` and
# `uuid` objects are defined elsewhere in this project; their imports are not
# shown in this listing.

writer = SummaryWriter('log')
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model1 = FaceModel().to(device)
model2 = Stage2Model().to(device)
select_model = SelectNet().to(device)
# load state
# path_model2 = os.path.join("/home/yinzi/data4/new_train/checkpoints_C/02a38440", "best.pth.tar")
# path_model2 = os.path.join("/home/yinzi/data4/new_train/checkpoints_C/ca8f5c52", "best.pth.tar")
# path_model2 = os.path.join("/home/yinzi/data4/new_train/checkpoints_C/b9d37dbc", "best.pth.tar")
# path_model2 = os.path.join("/home/yinzi/data4/new_train/checkpoints_C/49997f1e", "best.pth.tar")
# path_model2 = os.path.join("/home/yinzi/data4/new_train/checkpoints_C/396e4702", "best.pth.tar")
path_model2 = os.path.join(
    "/home/yinzi/data4/new_train/checkpoints_C/396e4702", "best.pth.tar")
# path_model2 = os.path.join("/home/yinzi/data4/new_train/checkpoints_C/1daed2c2", "best.pth.tar")
# path_model2 = os.path.join("/home/yinzi/data4/new_train/checkpoints_ABC/ea3c3972", "best.pth.tar")
path_select = os.path.join(
    "/home/yinzi/data4/new_train/checkpoints_AB/6b4324c6", "best.pth.tar")
# 396e4702: mouth score 0.9166, overall 0.865 (0.869)
# 396e4702: best standalone score 0.8714
# b9d37dbc: overall 0.854
class TrainModel(TemplateModel):
    def __init__(self, argus=args):
        super(TrainModel, self).__init__()
        self.args = argus
        self.writer = SummaryWriter('log_new')
        self.step = 0
        self.epoch = 0
        self.best_error = float('Inf')

        self.device = torch.device(
            "cuda:%d" % args.cuda if torch.cuda.is_available() else "cpu")

        self.model = Stage1Model().to(self.device)
        if self.args.pretrainA:
            self.load_pretrained("model1")
        if self.args.select_net == 1:
            self.select_net = SelectNet_resnet().to(self.device)
        elif self.args.select_net == 0:
            self.select_net = SelectNet().to(self.device)
        if self.args.pretrainB:
            self.load_pretrained("select_net", self.args.select_net)
        self.optimizer = optim.Adam(self.model.parameters(), self.args.lr)
        self.optimizer_select = optim.Adam(self.select_net.parameters(),
                                           self.args.lr_s)

        self.criterion = nn.CrossEntropyLoss()
        self.metric = nn.CrossEntropyLoss()
        self.regress_loss = nn.SmoothL1Loss()
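        # In this AB-stage trainer only regress_loss (Smooth L1 on the affine
        # thetas) is actually used; criterion and metric appear to be kept to
        # satisfy the TemplateModel / check_init() interface (assumption).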
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                                   step_size=5,
                                                   gamma=0.5)
        self.scheduler3 = optim.lr_scheduler.StepLR(self.optimizer_select,
                                                    step_size=5,
                                                    gamma=0.5)

        self.train_loader = dataloader['train']
        self.eval_loader = dataloader['val']
        if self.args.select_net == 1:
            self.ckpt_dir = "checkpoints_AB_res/%s" % uuid
        else:
            self.ckpt_dir = "checkpoints_AB_custom/%s" % uuid
        self.display_freq = args.display_freq
        # call it to check all members have been initialized
        self.check_init()

    def train_loss(self, batch):
        image, label = batch['image'].to(self.device), batch['labels'].to(
            self.device)
        orig, orig_label = batch['orig'].to(
            self.device), batch['orig_label'].to(self.device)
        N, L, H, W = orig_label.shape

        # Get the stage-1 prediction (coarse 9-channel mask)
        stage1_pred = F.softmax(self.model(image), dim=1)
        assert stage1_pred.shape == (N, 9, 128, 128)

        # Mask-to-theta: SelectNet regresses one 2x3 affine matrix for each of the 6 facial parts from the coarse mask
        theta = self.select_net(stage1_pred)
        assert theta.shape == (N, 6, 2, 3)
        """""
            Using original mask groundtruth to calc theta_groundtruth
        """ ""
        assert orig_label.shape == (N, 9, 1024, 1024)
        cens = torch.floor(calc_centroid(orig_label))
        points = torch.floor(
            torch.cat([cens[:, 1:6], cens[:, 6:9].mean(dim=1, keepdim=True)],
                      dim=1))
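        # `points` holds one crop centre per part: channels 1-5 of the label plus
        # the mean of channels 6-8 (presumably the three mouth labels merged into
        # one), with background channel 0 skipped. The ground-truth theta below
        # describes an 81x81 crop centred on each point: F.affine_grid works in
        # normalized [-1, 1] coordinates (align_corners=True), so the scale is
        # (81 - 1) / (W - 1) and the translation maps the grid centre to the
        # centroid (x = points[:, i, 1], y = points[:, i, 0]).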
        theta_label = torch.zeros((N, 6, 2, 3),
                                  device=self.device,
                                  requires_grad=False)
        for i in range(6):
            theta_label[:, i, 0, 0] = (81. - 1.) / (W - 1)
            theta_label[:, i, 0, 2] = -1. + (2. * points[:, i, 1]) / (W - 1)
            theta_label[:, i, 1, 1] = (81. - 1.) / (H - 1)
            theta_label[:, i, 1, 2] = -1. + (2. * points[:, i, 0]) / (H - 1)

        # Regression loss on the predicted thetas (Smooth L1)
        loss = self.regress_loss(theta, theta_label)
        return loss

    def eval_error(self):
        loss_list = []
        step = 0
        for batch in self.eval_loader:
            step += 1
            image, label = batch['image'].to(self.device), batch['labels'].to(
                self.device)
            orig, orig_label = batch['orig'].to(
                self.device), batch['orig_label'].to(self.device)
            N, L, H, W = orig_label.shape
            # Get the stage-1 mask prediction
            stage1_pred = F.softmax(self.model(image), dim=1)
            assert stage1_pred.shape == (N, 9, 128, 128)

            # Show the stage-1 mask prediction in TensorBoard
            stage1_pred_grid = torchvision.utils.make_grid(
                stage1_pred.argmax(dim=1, keepdim=True))
            self.writer.add_image("stage1 predict%s" % uuid, stage1_pred_grid,
                                  step)

            # Stage1Mask to Affine Theta
            theta = self.select_net(stage1_pred)
            assert theta.shape == (N, 6, 2, 3)

            # Calculate Affine theta ground truth
            assert orig_label.shape == (N, 9, 1024, 1024)
            cens = torch.floor(calc_centroid(orig_label))
            assert cens.shape == (N, 9, 2)
            points = torch.floor(
                torch.cat(
                    [cens[:, 1:6], cens[:, 6:9].mean(dim=1, keepdim=True)],
                    dim=1))
            theta_label = torch.zeros((N, 6, 2, 3),
                                      device=self.device,
                                      requires_grad=False)
            for i in range(6):
                theta_label[:, i, 0, 0] = (81. - 1.) / (W - 1)
                theta_label[:, i, 0, 2] = -1. + (2. * points[:, i, 1]) / (W - 1)
                theta_label[:, i, 1, 1] = (81. - 1.) / (H - 1)
                theta_label[:, i, 1, 2] = -1. + (2. * points[:, i, 0]) / (H - 1)

            # calc regression loss
            loss = self.regress_loss(theta, theta_label)
            loss_list.append(loss.item())

        # imshow cropped parts
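        # Visualize the crops produced by the predicted thetas: affine_grid /
        # grid_sample cut 81x81 patches per part out of the original image.
        # Note that `theta`, `orig` and `N` here come from the last eval batch.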
        temp = []
        for i in range(theta.shape[1]):
            test = theta[:, i]
            grid = F.affine_grid(theta=test,
                                 size=[N, 3, 81, 81],
                                 align_corners=True)
            temp.append(
                F.grid_sample(input=orig, grid=grid, align_corners=True))
        parts = torch.stack(temp, dim=1)
        assert parts.shape == (N, 6, 3, 81, 81)
        for i in range(6):
            parts_grid = torchvision.utils.make_grid(parts[:, i].detach().cpu())
            self.writer.add_image('cropped_parts_%s_%d' % (uuid, i), parts_grid,
                                  self.step)
        return np.mean(loss_list)

    def train(self):
        self.model.train()
        self.select_net.train()
        self.epoch += 1
        for batch in self.train_loader:
            self.step += 1
            self.optimizer.zero_grad()
            self.optimizer_select.zero_grad()
            loss = self.train_loss(batch)
            loss.backward()
            self.optimizer_select.step()
            self.optimizer.step()

            if self.step % self.display_freq == 0:
                self.writer.add_scalar('loss_%s' % uuid,
                                       torch.mean(loss).item(), self.step)
                print('epoch {}\tstep {}\tloss {:.3}'.format(
                    self.epoch, self.step,
                    torch.mean(loss).item()))

    def eval(self):
        self.model.eval()
        self.select_net.eval()
        error = self.eval_error()

        if error < self.best_error:
            self.best_error = error
            self.save_state(os.path.join(self.ckpt_dir, 'best.pth.tar'), False)
        self.save_state(
            os.path.join(self.ckpt_dir, '{}.pth.tar'.format(self.epoch)))
        self.writer.add_scalar('error_%s' % uuid, error, self.epoch)
        print('epoch {}\terror {:.3}\tbest_error {:.3}'.format(
            self.epoch, error, self.best_error))

        return error

    def save_state(self, fname, optim=True):
        state = {}

        if isinstance(self.model, torch.nn.DataParallel):
            state['model1'] = self.model.module.state_dict()
            state['select_net'] = self.select_net.module.state_dict()
        else:
            state['model1'] = self.model.state_dict()
            state['select_net'] = self.select_net.state_dict()

        if optim:
            state['optimizer'] = self.optimizer.state_dict()
            state['optimizer_select'] = self.optimizer_select.state_dict()

        state['step'] = self.step
        state['epoch'] = self.epoch
        state['best_error'] = self.best_error
        torch.save(state, fname)
        print('save model at {}'.format(fname))

    def load_pretrained(self, model, mode=None):
        path_modelA = os.path.join(
            "/home/yinzi/data4/new_train/checkpoints_A/88736bbe",
            'best.pth.tar')
        if mode == 0:
            path_modelB_select_net = os.path.join(
                "/home/yinzi/data4/new_train/checkpoints_B_selectnet/cab2d814",
                'best.pth.tar')
        elif mode == 1:
            path_modelB_select_net = os.path.join(
                "/home/yinzi/data4/new_train/checkpoints_B_resnet/2a8e078e",
                'best.pth.tar')

        if model == 'model1':
            fname = path_modelA
            state = torch.load(fname, map_location=self.device)
            self.model.load_state_dict(state['model1'])
            print("load from" + fname)
        elif model == 'select_net':
            fname = path_modelB_select_net
            state = torch.load(fname, map_location=self.device)
            self.select_net.load_state_dict(state['select_net'])
            print("load from" + fname)
class TrainModel(TemplateModel):
    def __init__(self, argus=args):
        super(TrainModel, self).__init__()
        self.args = argus
        self.writer = SummaryWriter('log')
        self.step = 0
        self.epoch = 0
        self.best_error = float('Inf')

        self.device = torch.device(
            "cuda:%d" % args.cuda if torch.cuda.is_available() else "cpu")

        self.model = Stage1Model().to(self.device)
        # self.load_pretrained("model1")
        self.model2 = Stage2Model().to(self.device)
        self.load_pretrained("model2")
        self.select_net = SelectNet().to(self.device)
        self.load_pretrained("select_net")
        self.optimizer = optim.Adam(self.model.parameters(), self.args.lr)
        self.optimizer2 = optim.Adam(self.model2.parameters(), self.args.lr2)
        self.optimizer_select = optim.Adam(self.select_net.parameters(),
                                           self.args.lr_s)

        self.criterion = nn.CrossEntropyLoss()
        self.metric = nn.CrossEntropyLoss()
        self.regress_loss = nn.SmoothL1Loss()
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                                   step_size=5,
                                                   gamma=0.5)
        self.scheduler2 = optim.lr_scheduler.StepLR(self.optimizer2,
                                                    step_size=5,
                                                    gamma=0.5)
        self.scheduler3 = optim.lr_scheduler.StepLR(self.optimizer_select,
                                                    step_size=5,
                                                    gamma=0.5)

        self.train_loader = dataloader['train']
        self.eval_loader = dataloader['val']

        self.ckpt_dir = "checkpoints_BC/%s" % uuid
        self.display_freq = args.display_freq

        # call it to check all members have been initialized
        self.check_init()

    def train_loss(self, batch):
        image, label = batch['image'].to(self.device), batch['labels'].to(
            self.device)
        orig, orig_label = batch['orig'].to(
            self.device), batch['orig_label'].to(self.device)
        parts_mask = batch['parts_mask_gt'].to(self.device)

        N, L, H, W = orig_label.shape
        assert label.shape == (N, 9, 128, 128)

        theta = self.select_net(label)
        assert theta.shape == (N, 6, 2, 3)
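        # Note: in this stage select_net consumes the ground-truth 9-channel
        # label (128x128), not a stage-1 prediction, to produce the 6 part thetas.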

        parts, parts_label, _ = affine_crop(img=orig,
                                            label=orig_label,
                                            theta_in=theta,
                                            map_location=self.device)
        assert parts.grad_fn is not None
        assert parts.shape == (N, 6, 3, 81, 81)
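        # affine_crop warps 81x81 crops (and matching part labels) out of the
        # 1024x1024 originals with the predicted thetas; the grad_fn assert
        # checks the crops stay differentiable, so gradients reach select_net.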

        stage2_pred = self.model2(parts)
        assert len(stage2_pred) == 6

        loss = []
        for i in range(6):
            loss.append(self.criterion(stage2_pred[i], parts_mask[:, i].long()))
        loss = torch.stack(loss)
        return loss

    def eval_error(self):
        loss_list = []
        for batch in self.eval_loader:
            image, label = batch['image'].to(self.device), batch['labels'].to(
                self.device)
            orig, orig_label = batch['orig'].to(
                self.device), batch['orig_label'].to(self.device)
            parts_mask = batch['parts_mask_gt'].to(self.device)

            N, L, H, W = orig_label.shape
            assert label.shape == (N, 9, 128, 128)

            theta = self.select_net(label)
            assert theta.shape == (N, 6, 2, 3)

            parts, parts_label, _ = affine_crop(img=orig,
                                                label=orig_label,
                                                theta_in=theta,
                                                map_location=self.device)
            assert parts.grad_fn is not None
            assert parts.shape == (N, 6, 3, 81, 81)

            stage2_pred = self.model2(parts)
            assert len(stage2_pred) == 6

            loss = []
            for i in range(6):
                loss.append(
                    self.criterion(stage2_pred[i],
                                   parts_mask[:, i].long()).item())
            loss_list.append(np.mean(loss))
        return np.mean(loss_list)

    def train(self):
        # self.model.train()
        self.model2.train()
        self.select_net.train()
        self.epoch += 1
        for batch in self.train_loader:
            self.step += 1
            # self.optimizer.zero_grad()
            self.optimizer2.zero_grad()
            self.optimizer_select.zero_grad()
            loss = self.train_loss(batch)
            # Backprop the 6 per-part losses with equal weight (a vector of ones)
            loss.backward(
                torch.ones(6, device=self.device, requires_grad=False))
            self.optimizer2.step()
            self.optimizer_select.step()
            # self.optimizer.step()

            if self.step % self.display_freq == 0:
                self.writer.add_scalar('loss_%s' % uuid,
                                       torch.mean(loss).item(), self.step)
                print('epoch {}\tstep {}\tloss {:.3}'.format(
                    self.epoch, self.step,
                    torch.mean(loss).item()))

    def eval(self):
        # self.model.eval()
        self.model2.eval()
        self.select_net.eval()
        error = self.eval_error()
        if error < self.best_error:
            self.best_error = error
            self.save_state(os.path.join(self.ckpt_dir, 'best.pth.tar'), False)
        self.save_state(
            os.path.join(self.ckpt_dir, '{}.pth.tar'.format(self.epoch)))
        self.writer.add_scalar('error_%s' % uuid, error, self.epoch)
        print('epoch {}\terror {:.3}\tbest_error {:.3}'.format(
            self.epoch, error, self.best_error))

        return error

    def save_state(self, fname, optim=True):
        state = {}

        if isinstance(self.model, torch.nn.DataParallel):
            state['model1'] = self.model.module.state_dict()
            state['model2'] = self.model2.module.state_dict()
            state['select_net'] = self.select_net.module.state_dict()
        else:
            state['model1'] = self.model.state_dict()
            state['model2'] = self.model2.state_dict()
            state['select_net'] = self.select_net.state_dict()

        if optim:
            state['optimizer'] = self.optimizer.state_dict()
            state['optimizer2'] = self.optimizer2.state_dict()
            state['optimizer_select'] = self.optimizer_select.state_dict()

        state['step'] = self.step
        state['epoch'] = self.epoch
        state['best_error'] = self.best_error
        torch.save(state, fname)
        print('save model at {}'.format(fname))

    def load_pretrained(self, model):
        if model == 'model1':
            fname = "a"  # placeholder path; this branch is unused here (the call is commented out in __init__)
            state = torch.load(fname, map_location=self.device)
            self.model.load_state_dict(state['model1'])
        elif model == 'model2':
            fname = "/home/yinzi/data4/new_train/checkpoints_C/02a38440/best.pth.tar"
            state = torch.load(fname, map_location=self.device)
            self.model2.load_state_dict(state['model2'])
        elif model == 'select_net':
            fname = "/home/yinzi/data4/new_train/checkpoints_AB/6b4324c6/best.pth.tar"
            state = torch.load(fname, map_location=self.device)
            self.select_net.load_state_dict(state['select_net'])
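

# A minimal driver sketch: roughly how the TrainModel defined above could be
# run end to end. The epoch count and the per-epoch scheduler stepping are
# assumptions; everything else (TrainModel, its optimizers, schedulers and
# checkpointing) comes from the class definitions above.
def run_training(epochs=25):
    trainer = TrainModel()           # picks up the module-level `args`
    for _ in range(epochs):
        trainer.train()              # one pass over dataloader['train']
        trainer.scheduler.step()     # halve each LR every 5 epochs
        trainer.scheduler2.step()
        trainer.scheduler3.step()
        trainer.eval()               # validate, log error, checkpoint best.pth.tar


if __name__ == '__main__':
    run_training()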