class TrainModel(TemplateModel):
    """Jointly train the stage-1 model and SelectNet by regressing affine thetas.

    The stage-1 model produces a 9-channel (softmaxed) coarse mask; SelectNet
    maps it to six 2x3 affine matrices, which are compared (Smooth L1) with
    matrices built from the centroids of the full-resolution ground truth.
    """

    # Side length (pixels) of each part crop described by the affine matrices.
    CROP_SIZE = 81

    def __init__(self, argus=args):
        super(TrainModel, self).__init__()
        self.args = argus
        self.writer = SummaryWriter('log_new')
        self.step = 0
        self.epoch = 0
        self.best_error = float('Inf')
        # Read everything from the passed-in config rather than the module-level
        # ``args`` global so a custom ``argus`` is honoured consistently.
        self.device = torch.device(
            "cuda:%d" % self.args.cuda if torch.cuda.is_available() else "cpu")

        self.model = Stage1Model().to(self.device)
        if self.args.pretrainA:
            self.load_pretrained("model1")

        if self.args.select_net == 1:
            self.select_net = SelectNet_resnet().to(self.device)
        elif self.args.select_net == 0:
            self.select_net = SelectNet().to(self.device)
        else:
            # Fail loudly here instead of with an AttributeError further down.
            raise ValueError(
                "args.select_net must be 0 (SelectNet) or 1 (SelectNet_resnet), "
                "got %r" % (self.args.select_net,))
        if self.args.pretrainB:
            self.load_pretrained("select_net", self.args.select_net)

        self.optimizer = optim.Adam(self.model.parameters(), self.args.lr)
        self.optimizer_select = optim.Adam(self.select_net.parameters(),
                                           self.args.lr_s)

        self.criterion = nn.CrossEntropyLoss()
        self.metric = nn.CrossEntropyLoss()
        self.regress_loss = nn.SmoothL1Loss()

        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                                   step_size=5, gamma=0.5)
        self.scheduler3 = optim.lr_scheduler.StepLR(self.optimizer_select,
                                                    step_size=5, gamma=0.5)

        self.train_loader = dataloader['train']
        self.eval_loader = dataloader['val']
        if self.args.select_net == 1:
            self.ckpt_dir = "checkpoints_AB_res/%s" % uuid
        else:
            self.ckpt_dir = "checkpoints_AB_custom/%s" % uuid
        self.display_freq = self.args.display_freq

        # call it to check all members have been initiated
        self.check_init()

    def _theta_label(self, orig_label):
        """Build the ground-truth affine matrices for the six part crops.

        Args:
            orig_label: full-resolution label tensor; asserted to be
                (N, 9, 1024, 1024).

        Returns:
            Tensor of shape (N, 6, 2, 3) on ``self.device``. Each matrix has
            scale (CROP_SIZE - 1) / (side - 1) on the diagonal and a translation
            that moves the crop centre onto the (floored) part centroid in
            normalised [-1, 1] coordinates.
        """
        N, L, H, W = orig_label.shape
        assert orig_label.shape == (N, 9, 1024, 1024)
        cens = torch.floor(calc_centroid(orig_label))
        assert cens.shape == (N, 9, 2)
        # Points 0..4 are the centroids of label channels 1..5; point 5 is the
        # mean centroid of channels 6..8 (channel 0 is not used).
        points = torch.floor(
            torch.cat([cens[:, 1:6], cens[:, 6:9].mean(dim=1, keepdim=True)],
                      dim=1))
        theta_label = torch.zeros((N, 6, 2, 3), device=self.device,
                                  requires_grad=False)
        size = float(self.CROP_SIZE)
        # points[..., 1] is paired with the width and fills the first (x) row;
        # points[..., 0] is paired with the height and fills the second (y) row.
        theta_label[:, :, 0, 0] = (size - 1.) / (W - 1)
        theta_label[:, :, 0, 2] = -1. + (2. * points[:, :, 1]) / (W - 1)
        theta_label[:, :, 1, 1] = (size - 1.) / (H - 1)
        theta_label[:, :, 1, 2] = -1. + (2. * points[:, :, 0]) / (H - 1)
        return theta_label

    def train_loss(self, batch):
        """Return the Smooth L1 loss between predicted and ground-truth thetas.

        Args:
            batch: mapping with at least 'image' and 'orig_label' tensors.
        """
        image = batch['image'].to(self.device)
        orig_label = batch['orig_label'].to(self.device)
        N = orig_label.shape[0]

        # Stage-1 prediction (coarse mask), softmaxed over the 9 channels.
        stage1_pred = F.softmax(self.model(image), dim=1)
        assert stage1_pred.shape == (N, 9, 128, 128)

        # Mask -> six affine matrices.
        theta = self.select_net(stage1_pred)
        assert theta.shape == (N, 6, 2, 3)

        # Ground-truth thetas come from the original (full-resolution) labels.
        return self.regress_loss(theta, self._theta_label(orig_label))

    def _log_crops(self, orig, theta):
        """Sample the six crops of ``orig`` selected by ``theta`` and log them.

        Images are written under 'croped_parts_<uuid>_<i>' at ``self.step``.
        """
        N = theta.shape[0]
        size = self.CROP_SIZE
        crops = []
        for i in range(theta.shape[1]):
            grid = F.affine_grid(theta=theta[:, i], size=[N, 3, size, size],
                                 align_corners=True)
            crops.append(F.grid_sample(input=orig, grid=grid,
                                       align_corners=True))
        parts = torch.stack(crops, dim=1)
        assert parts.shape == (N, 6, 3, size, size)
        for i in range(6):
            parts_grid = torchvision.utils.make_grid(parts[:, i].detach().cpu())
            self.writer.add_image('croped_parts_%s_%d' % (uuid, i), parts_grid,
                                  self.step)

    def eval_error(self):
        """Return the mean per-batch regression loss over ``self.eval_loader``.

        Also logs the stage-1 argmax prediction of every batch and the crops of
        the last batch to TensorBoard.

        Raises:
            RuntimeError: if the eval loader yields no batches.
        """
        loss_list = []
        theta = orig = None
        # Evaluation only needs values; skip building autograd graphs.
        with torch.no_grad():
            for step, batch in enumerate(self.eval_loader, start=1):
                image = batch['image'].to(self.device)
                orig = batch['orig'].to(self.device)
                orig_label = batch['orig_label'].to(self.device)
                N = orig_label.shape[0]

                stage1_pred = F.softmax(self.model(image), dim=1)
                assert stage1_pred.shape == (N, 9, 128, 128)
                stage1_pred_grid = torchvision.utils.make_grid(
                    stage1_pred.argmax(dim=1, keepdim=True))
                self.writer.add_image("stage1 predict%s" % uuid,
                                      stage1_pred_grid, step)

                theta = self.select_net(stage1_pred)
                assert theta.shape == (N, 6, 2, 3)

                loss = self.regress_loss(theta, self._theta_label(orig_label))
                loss_list.append(loss.item())

            if theta is None:
                raise RuntimeError("eval_loader yielded no batches")
            # Visualise the crops predicted for the last evaluated batch.
            self._log_crops(orig, theta)
        return np.mean(loss_list)

    def train(self):
        """Run one training epoch over ``self.train_loader``."""
        self.model.train()
        self.select_net.train()
        self.epoch += 1
        for batch in self.train_loader:
            self.step += 1
            self.optimizer.zero_grad()
            self.optimizer_select.zero_grad()
            loss = self.train_loss(batch)
            loss.backward()
            self.optimizer_select.step()
            self.optimizer.step()

            if self.step % self.display_freq == 0:
                loss_value = torch.mean(loss).item()
                self.writer.add_scalar('loss_%s' % uuid, loss_value, self.step)
                print('epoch {}\tstep {}\tloss {:.3}'.format(
                    self.epoch, self.step, loss_value))

    def eval(self):
        """Evaluate, checkpoint (best weights and per-epoch state) and return the error."""
        self.model.eval()
        self.select_net.eval()
        error = self.eval_error()

        if error < self.best_error:
            self.best_error = error
            self.save_state(os.path.join(self.ckpt_dir, 'best.pth.tar'), False)
        self.save_state(
            os.path.join(self.ckpt_dir, '{}.pth.tar'.format(self.epoch)))
        self.writer.add_scalar('error_%s' % uuid, error, self.epoch)
        print('epoch {}\terror {:.3}\tbest_error {:.3}'.format(
            self.epoch, error, self.best_error))

        return error

    @staticmethod
    def _unwrap(net):
        """Return the wrapped module of a DataParallel ``net``, else ``net`` itself."""
        return net.module if isinstance(net, torch.nn.DataParallel) else net

    def save_state(self, fname, optim=True):
        """Save model weights (and optionally optimizer states) to ``fname``.

        Args:
            fname: destination path passed to ``torch.save``.
            optim: when True, also store both optimizer state dicts. (This
                parameter shadows the ``optim`` module inside this method.)
        """
        # Unwrap each network independently so mixed wrapping cannot break.
        state = {
            'model1': self._unwrap(self.model).state_dict(),
            'select_net': self._unwrap(self.select_net).state_dict(),
        }
        if optim:
            state['optimizer'] = self.optimizer.state_dict()
            state['optimizer_select'] = self.optimizer_select.state_dict()
        state['step'] = self.step
        state['epoch'] = self.epoch
        state['best_error'] = self.best_error
        torch.save(state, fname)
        print('save model at {}'.format(fname))

    def load_pretrained(self, model, mode=None):
        """Load pretrained weights into ``self.model`` or ``self.select_net``.

        Args:
            model: 'model1' or 'select_net'.
            mode: for 'select_net', 0 selects the SelectNet checkpoint and 1
                the SelectNet_resnet checkpoint.

        Raises:
            ValueError: for an unknown ``model`` or an invalid ``mode``.
        """
        if model == 'model1':
            fname = os.path.join(
                "/home/yinzi/data4/new_train/checkpoints_A/88736bbe",
                'best.pth.tar')
            state = torch.load(fname, map_location=self.device)
            self.model.load_state_dict(state['model1'])
        elif model == 'select_net':
            if mode == 0:
                ckpt_dir = "/home/yinzi/data4/new_train/checkpoints_B_selectnet/cab2d814"
            elif mode == 1:
                ckpt_dir = "/home/yinzi/data4/new_train/checkpoints_B_resnet/2a8e078e"
            else:
                raise ValueError(
                    "select_net mode must be 0 or 1, got %r" % (mode,))
            fname = os.path.join(ckpt_dir, 'best.pth.tar')
            state = torch.load(fname, map_location=self.device)
            self.select_net.load_state_dict(state['select_net'])
        else:
            raise ValueError("unknown pretrained model %r" % (model,))
        print("load from" + fname)
class TrainModel(TemplateModel):
    """Jointly fine-tune SelectNet and the stage-2 part segmentation model.

    SelectNet maps the batch's 9-channel 128x128 ``labels`` tensor to six
    affine matrices; ``affine_crop`` uses them to cut six parts from the
    original image; the stage-2 model predicts a mask per part, scored with
    cross-entropy. The stage-1 model is built and checkpointed but not trained.
    """

    def __init__(self, argus=args):
        super(TrainModel, self).__init__()
        self.args = argus
        self.writer = SummaryWriter('log')
        self.step = 0
        self.epoch = 0
        self.best_error = float('Inf')
        # Use the passed-in config, not the module-level ``args`` global.
        self.device = torch.device(
            "cuda:%d" % self.args.cuda if torch.cuda.is_available() else "cpu")

        # Stage-1 weights are intentionally neither loaded nor trained here.
        self.model = Stage1Model().to(self.device)
        self.model2 = Stage2Model().to(self.device)
        self.load_pretrained("model2")
        self.select_net = SelectNet().to(self.device)
        self.load_pretrained("select_net")

        self.optimizer = optim.Adam(self.model.parameters(), self.args.lr)
        self.optimizer2 = optim.Adam(self.model2.parameters(), self.args.lr2)
        self.optimizer_select = optim.Adam(self.select_net.parameters(),
                                           self.args.lr_s)

        self.criterion = nn.CrossEntropyLoss()
        self.metric = nn.CrossEntropyLoss()
        self.regress_loss = nn.SmoothL1Loss()

        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                                   step_size=5, gamma=0.5)
        self.scheduler2 = optim.lr_scheduler.StepLR(self.optimizer2,
                                                    step_size=5, gamma=0.5)
        self.scheduler3 = optim.lr_scheduler.StepLR(self.optimizer_select,
                                                    step_size=5, gamma=0.5)

        self.train_loader = dataloader['train']
        self.eval_loader = dataloader['val']
        self.ckpt_dir = "checkpoints_BC/%s" % uuid
        self.display_freq = self.args.display_freq

        # call it to check all members have been initiated
        self.check_init()

    def _stage2_forward(self, batch):
        """Run SelectNet -> affine_crop -> stage-2 model on one batch.

        Returns:
            (stage2_pred, parts_mask, parts): the six per-part predictions, the
            per-part ground-truth masks from 'parts_mask_gt', and the cropped
            parts tensor (asserted to be (N, 6, 3, 81, 81)).
        """
        label = batch['labels'].to(self.device)
        orig = batch['orig'].to(self.device)
        orig_label = batch['orig_label'].to(self.device)
        parts_mask = batch['parts_mask_gt'].to(self.device)
        N = orig_label.shape[0]
        assert label.shape == (N, 9, 128, 128)

        theta = self.select_net(label)
        assert theta.shape == (N, 6, 2, 3)
        parts, _, _ = affine_crop(img=orig, label=orig_label, theta_in=theta,
                                  map_location=self.device)
        assert parts.shape == (N, 6, 3, 81, 81)

        stage2_pred = self.model2(parts)
        assert len(stage2_pred) == 6
        return stage2_pred, parts_mask, parts

    def train_loss(self, batch):
        """Return a length-6 tensor with the cross-entropy loss of each part."""
        stage2_pred, parts_mask, parts = self._stage2_forward(batch)
        # The crop must remain differentiable so gradients reach SelectNet.
        assert parts.grad_fn is not None
        return torch.stack([
            self.criterion(stage2_pred[i], parts_mask[:, i].long())
            for i in range(6)
        ])

    def eval_error(self):
        """Return the mean (over batches) of the mean per-part cross-entropy."""
        loss_list = []
        # Evaluation only reads loss values, so skip autograd bookkeeping.
        with torch.no_grad():
            for batch in self.eval_loader:
                stage2_pred, parts_mask, _ = self._stage2_forward(batch)
                loss_list.append(np.mean([
                    self.criterion(stage2_pred[i], parts_mask[:, i].long()).item()
                    for i in range(6)
                ]))
        return np.mean(loss_list)

    def train(self):
        """Run one epoch updating SelectNet and the stage-2 model only."""
        self.model2.train()
        self.select_net.train()
        self.epoch += 1
        for batch in self.train_loader:
            self.step += 1
            self.optimizer2.zero_grad()
            self.optimizer_select.zero_grad()
            loss = self.train_loss(batch)
            # [1,1,1,1,1,1] weight for 6 parts loss
            loss.backward(
                torch.ones(6, device=self.device, requires_grad=False))
            self.optimizer2.step()
            self.optimizer_select.step()

            if self.step % self.display_freq == 0:
                loss_value = torch.mean(loss).item()
                self.writer.add_scalar('loss_%s' % uuid, loss_value, self.step)
                print('epoch {}\tstep {}\tloss {:.3}'.format(
                    self.epoch, self.step, loss_value))

    def eval(self):
        """Evaluate, checkpoint (best weights and per-epoch state) and return the error."""
        self.model2.eval()
        self.select_net.eval()
        error = self.eval_error()

        if error < self.best_error:
            self.best_error = error
            self.save_state(os.path.join(self.ckpt_dir, 'best.pth.tar'), False)
        self.save_state(
            os.path.join(self.ckpt_dir, '{}.pth.tar'.format(self.epoch)))
        self.writer.add_scalar('error_%s' % uuid, error, self.epoch)
        print('epoch {}\terror {:.3}\tbest_error {:.3}'.format(
            self.epoch, error, self.best_error))

        return error

    @staticmethod
    def _unwrap(net):
        """Return the wrapped module of a DataParallel ``net``, else ``net`` itself."""
        return net.module if isinstance(net, torch.nn.DataParallel) else net

    def save_state(self, fname, optim=True):
        """Save all three networks (and optionally optimizers) to ``fname``.

        Args:
            fname: destination path passed to ``torch.save``.
            optim: when True, also store the optimizer state dicts. (This
                parameter shadows the ``optim`` module inside this method.)
        """
        # Unwrap each network independently so mixed wrapping cannot break.
        state = {
            'model1': self._unwrap(self.model).state_dict(),
            'model2': self._unwrap(self.model2).state_dict(),
            'select_net': self._unwrap(self.select_net).state_dict(),
        }
        if optim:
            state['optimizer'] = self.optimizer.state_dict()
            state['optimizer2'] = self.optimizer2.state_dict()
            state['optimizer_select'] = self.optimizer_select.state_dict()
        state['step'] = self.step
        state['epoch'] = self.epoch
        state['best_error'] = self.best_error
        torch.save(state, fname)
        print('save model at {}'.format(fname))

    def load_pretrained(self, model):
        """Load pretrained weights for 'model1', 'model2' or 'select_net'.

        Raises:
            ValueError: for an unknown ``model`` name.
        """
        if model == 'model1':
            # NOTE(review): "a" looks like a placeholder path; set a real
            # checkpoint before enabling stage-1 loading.
            fname = "a"
            state = torch.load(fname, map_location=self.device)
            # The stage-1 network is stored as ``self.model`` (there is no
            # ``self.model1`` attribute).
            self.model.load_state_dict(state['model1'])
        elif model == 'model2':
            fname = "/home/yinzi/data4/new_train/checkpoints_C/02a38440/best.pth.tar"
            state = torch.load(fname, map_location=self.device)
            self.model2.load_state_dict(state['model2'])
        elif model == 'select_net':
            fname = "/home/yinzi/data4/new_train/checkpoints_AB/6b4324c6/best.pth.tar"
            state = torch.load(fname, map_location=self.device)
            self.select_net.load_state_dict(state['select_net'])
        else:
            raise ValueError("unknown pretrained model %r" % (model,))