import cv2
import numpy as np
import torch
import torchvision

# LossFn, AverageMeter, Config, part_box and get_part are project-local
# helpers; import them from wherever they live in this repository, e.g.
# from tools.utils import LossFn, AverageMeter, Config, part_box, get_part


class PNetTrainer(object):

    def __init__(self, lr, train_loader, model, optimizer, scheduler, logger,
                 device):
        self.lr = lr
        self.train_loader = train_loader
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = LossFn(self.device)
        self.logger = logger
        self.run_count = 0
        self.scalar_info = {}

    def compute_accuracy(self, prob_cls, gt_cls):
        # only score samples whose ground-truth label is >= 0; negative
        # labels mark samples excluded from classification accuracy
        prob_cls = torch.squeeze(prob_cls)
        mask = torch.ge(gt_cls, 0)
        # keep the valid elements
        valid_gt_cls = torch.masked_select(gt_cls, mask)
        valid_prob_cls = torch.masked_select(prob_cls, mask)
        size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0])
        prob_ones = torch.ge(valid_prob_cls, 0.6).float()
        right_ones = torch.eq(prob_ones, valid_gt_cls.float()).float()
        return torch.sum(right_ones) / float(size)

    def update_lr(self, epoch):
        """
        update learning rate of optimizers
        :param epoch: current training epoch
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def train(self, epoch):
        cls_loss_ = AverageMeter()
        box_offset_loss_ = AverageMeter()
        total_loss_ = AverageMeter()
        accuracy_ = AverageMeter()
        # pre-1.1 PyTorch convention: the scheduler is stepped once per
        # epoch before the optimizer updates
        self.scheduler.step()
        self.model.train()

        for batch_idx, (data, target) in enumerate(self.train_loader):
            gt_label = target['label']
            gt_bbox = target['bbox_target']
            data, gt_label, gt_bbox = data.to(self.device), gt_label.to(
                self.device), gt_bbox.to(self.device).float()

            cls_pred, box_offset_pred = self.model(data)

            # compute the loss
            cls_loss = self.lossfn.cls_loss(gt_label, cls_pred)
            box_offset_loss = self.lossfn.box_loss(gt_label, gt_bbox,
                                                   box_offset_pred)
            total_loss = cls_loss + box_offset_loss
            accuracy = self.compute_accuracy(cls_pred, gt_label)

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            # .item() detaches the scalars so the computation graph is freed
            cls_loss_.update(cls_loss.item(), data.size(0))
            box_offset_loss_.update(box_offset_loss.item(), data.size(0))
            total_loss_.update(total_loss.item(), data.size(0))
            accuracy_.update(accuracy.item(), data.size(0))
            print(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.6f}'
                .format(epoch, batch_idx * len(data),
                        len(self.train_loader.dataset),
                        100. * batch_idx / len(self.train_loader),
                        total_loss.item(), accuracy.item()))

        self.scalar_info['cls_loss'] = cls_loss_.avg
        self.scalar_info['box_offset_loss'] = box_offset_loss_.avg
        self.scalar_info['total_loss'] = total_loss_.avg
        self.scalar_info['accuracy'] = accuracy_.avg
        self.scalar_info['lr'] = self.scheduler.get_lr()[0]

        if self.logger is not None:
            for tag, value in list(self.scalar_info.items()):
                self.logger.scalar_summary(tag, value, self.run_count)
            self.scalar_info = {}
        self.run_count += 1
        print("|===>Loss: {:.4f}".format(total_loss_.avg))
        return cls_loss_.avg, box_offset_loss_.avg, total_loss_.avg, accuracy_.avg
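# AverageMeter is imported from the project's utilities above. As a point
# of reference, here is a minimal sketch of the interface the trainers
# rely on (update(val, n) and .avg), assuming the common running-average
# pattern; the project's own implementation may differ.
class _AverageMeterSketch(object):
    """Keeps a count-weighted running average of a scalar statistic."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val: batch statistic; n: batch size used as its weight
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count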
class AlexNetTrainer(object):

    def __init__(self, lr, train_loader, valid_loader, model_1, model_2,
                 optimizer_1, optimizer_2, scheduler_1, scheduler_2, logger,
                 device):
        self.lr = lr
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.model_1 = model_1
        self.optimizer_1 = optimizer_1
        self.scheduler_1 = scheduler_1
        self.model_2 = model_2
        self.optimizer_2 = optimizer_2
        self.scheduler_2 = scheduler_2
        self.device = device
        self.lossfn = LossFn(self.device, lam=2)
        self.logger = logger
        self.run_count = 0
        self.scalar_info = {}
        self.config = Config()
        # gradient-debugging hook, disabled by default:
        # self.handle = self.model_2.channelgroup_2.group[0].register_backward_hook(self.for_hook)

    def compute_accuracy(self, prob_cls, gt_cls):
        pred_cls = torch.max(prob_cls, 1)[1].squeeze()
        accuracy = float((pred_cls == gt_cls).sum()) / float(gt_cls.size(0))
        return accuracy

    def update_lr(self, epoch):
        """
        update learning rate of optimizers
        :param epoch: current training epoch
        """
        for param_group in self.optimizer_1.param_groups:
            param_group['lr'] = self.lr
        for param_group in self.optimizer_2.param_groups:
            param_group['lr'] = self.lr

    def show_image_grid(self, img_origin, img_parts, parts, epoch):
        print(img_parts.size())
        torchvision.utils.save_image(
            img_origin, self.config.save_path + '/img_origin_grid.jpg',
            nrow=8, padding=2, normalize=True, range=(0, 1))
        torchvision.utils.save_image(
            img_parts, self.config.save_path + '/img_parts_grid.jpg',
            nrow=8, padding=2, normalize=True, range=(0, 1))
        print(type(parts))
        print(parts.shape)
        grid = torchvision.utils.make_grid(img_origin, nrow=8, padding=2,
                                           normalize=True, range=(0, 1))
        image_size = 114
        parts = parts.transpose((1, 0))
        # save grid and parts as numpy arrays
        grid = grid.cpu().numpy().transpose((1, 2, 0))
        np.save(self.config.save_path + '/grid.npy', grid)
        np.save(self.config.save_path + '/parts.npy', parts)
        img = cv2.cvtColor(grid * 255, cv2.COLOR_RGB2BGR)
        l = 24  # half side length of each drawn part box
        colors = [(229, 187, 129), (161, 23, 21), (34, 8, 7), (118, 77, 57)]
        for i in range(parts.shape[0]):
            # clamp the box to the image, then offset it into cell i of the
            # 8-column grid (2-pixel padding between cells)
            box = np.array([
                np.maximum(0, (parts[i, 0] - l)),
                np.maximum(0, (parts[i, 1] - l)),
                np.minimum(image_size, (parts[i, 0] + l)),
                np.minimum(image_size, (parts[i, 1] + l))
            ])
            cv2.rectangle(
                img,
                (int(box[0] + i % 8 * (2 + image_size)),
                 int(box[1]) + i // 8 * (2 + image_size)),
                (int(box[2] + i % 8 * (2 + image_size)),
                 int(box[3]) + i // 8 * (2 + image_size)),
                colors[0], 2)
        cv2.imwrite(self.config.save_path + '/plt_image_boxs_epoch' +
                    str(epoch) + '.jpg', img)  # save the annotated grid
        return

    def show_mask(self, mask, epoch):
        # invert and min-max normalise the mask before saving it as a grid
        img = (1 - (mask - torch.min(mask)) /
               (torch.max(mask) - torch.min(mask)))
        img = torchvision.utils.make_grid(
            img.view(64, 1, 6, 6), nrow=8, padding=2, normalize=True,
            range=(0, 1)).cpu().numpy().transpose((1, 2, 0))
        img = cv2.cvtColor(img * 255, cv2.COLOR_RGB2BGR)
        cv2.imwrite(self.config.save_path + '/mask_epoch' + str(epoch) +
                    '.jpg', img)  # save the mask image
        return

    def for_hook(self, module, input_grad, output_grad):
        print('\r\nhook:\r\n')
        print(len(input_grad))
        print(len(output_grad))
        print(input_grad[0].size())
        print(input_grad[1].size())
        print(output_grad[0].size())
        print(input_grad[0][5][5:10])
        print(output_grad[0][5][5:10])

    def train(self, epoch):
        cls_loss_ = AverageMeter()
        accuracy_ = AverageMeter()
        accuracy_valid_ = AverageMeter()

        # feed the training set through the two models
        self.scheduler_1.step()
        self.scheduler_2.step()
        self.model_1.train()
        self.model_2.train()
        for batch_idx, (data, gt_label) in enumerate(self.train_loader):
            data, gt_label = data.to(self.device), gt_label.to(self.device)
            x, mask = self.model_1(data)

            # debugging: inspect selected weights
            # print(self.model_1.alexnet_1.conv1[0].weight.data)
            # print(self.model_2.channelgroup_2.group[0].weight.data[5][5:10])
            # print(self.model_3.Classify_1.conv1[0].weight.data)

            # part extraction runs outside the graph
            with torch.no_grad():
                parts = part_box(mask)
                img_parts, parts = get_part(data.cpu(), parts)  # (1, 64, 48, 48)
                img_parts = torch.from_numpy(img_parts).view(
                    img_parts.shape[0], 1, 48, 48).to(self.device)

            if epoch in (1, 5, 10, 15) and batch_idx == 1:
                self.show_image_grid(data, img_parts, parts, epoch)
                self.show_mask(mask, epoch)
                print('save image and parts in result: ' +
                      self.config.save_path)
                print('epoch: ' + str(epoch))
                print('batch_idx: ' + str(batch_idx))

            cls_pred = self.model_2(img_parts, x)

            # compute the loss
            cls_loss = self.lossfn.cls_loss(gt_label, cls_pred)
            accuracy = self.compute_accuracy(cls_pred, gt_label)

            self.optimizer_1.zero_grad()
            self.optimizer_2.zero_grad()
            cls_loss.backward()
            self.optimizer_1.step()
            self.optimizer_2.step()

            cls_loss_.update(cls_loss.item(), data.size(0))
            accuracy_.update(accuracy, data.size(0))
            if batch_idx % 2000 == 1:
                print('batch_idx: ', batch_idx)
                print('Cls loss: ', cls_loss.item())

        # feed the validation set through the models
        with torch.no_grad():
            self.model_1.eval()
            self.model_2.eval()
            for batch_idx, (data, gt_label) in enumerate(self.valid_loader):
                data, gt_label = data.to(self.device), gt_label.to(self.device)
                x, mask = self.model_1(data)
                parts = part_box(mask)
                img_parts, parts = get_part(data.cpu(), parts)  # (4, 64, 48, 48)
                img_parts = torch.from_numpy(img_parts).view(
                    img_parts.shape[0], 1, 48, 48).to(self.device)
                cls_pred = self.model_2(img_parts, x)
                accuracy_valid = self.compute_accuracy(cls_pred, gt_label)
                accuracy_valid_.update(accuracy_valid, data.size(0))

        # log the statistics
        self.scalar_info['cls_loss'] = cls_loss_.avg
        self.scalar_info['accuracy'] = accuracy_.avg
        self.scalar_info['lr'] = self.scheduler_1.get_lr()[0]
        # if self.logger is not None:
        #     for tag, value in list(self.scalar_info.items()):
        #         self.logger.scalar_summary(tag, value, self.run_count)
        #     self.scalar_info = {}
        # self.run_count += 1

        print(
            "\r\nEpoch: {}|===>Train Loss: {:.8f} Train Accuracy: {:.6f} valid Accuracy: {:.6f}\r\n"
            .format(epoch, cls_loss_.avg, accuracy_.avg, accuracy_valid_.avg))
        return cls_loss_.avg, accuracy_.avg, accuracy_valid_.avg
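# A hedged usage sketch for AlexNetTrainer: every name passed in below
# (trainer, the models, optimizers, schedulers and loaders it wraps) is a
# placeholder supplied by the caller, not defined in this module.
def _run_alexnet_trainer_sketch(trainer, num_epochs):
    """Drive AlexNetTrainer.train once per epoch and collect its stats."""
    history = []
    for epoch in range(1, num_epochs + 1):
        loss, train_acc, valid_acc = trainer.train(epoch)
        history.append((epoch, loss, train_acc, valid_acc))
    return history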
class ONetTrainer(object):

    def __init__(self, lr, train_loader, model, optimizer, scheduler, logger,
                 device):
        self.lr = lr
        self.train_loader = train_loader
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = LossFn(self.device)
        self.logger = logger
        self.run_count = 0
        self.scalar_info = {}

    def compute_accuracy(self, prob_cls, gt_cls):
        # only score samples whose ground-truth label is >= 0
        prob_cls = torch.squeeze(prob_cls)
        mask = torch.ge(gt_cls, 0)
        # keep the valid elements
        valid_gt_cls = torch.masked_select(gt_cls, mask)
        valid_prob_cls = torch.masked_select(prob_cls, mask)
        size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0])
        prob_ones = torch.ge(valid_prob_cls, 0.6).float()
        right_ones = torch.eq(prob_ones, valid_gt_cls.float()).float()
        return torch.sum(right_ones) / float(size)

    def update_lr(self, epoch):
        """
        update learning rate of optimizers
        :param epoch: current training epoch
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def train(self, epoch):
        cls_loss_ = AverageMeter()
        box_offset_loss_ = AverageMeter()
        landmark_loss_ = AverageMeter()
        total_loss_ = AverageMeter()
        accuracy_ = AverageMeter()
        self.scheduler.step()
        self.model.train()

        for batch_idx, (data, target) in enumerate(self.train_loader):
            gt_label = target['label']
            gt_bbox = target['bbox_target']
            gt_landmark = target['landmark_target']
            data, gt_label, gt_bbox, gt_landmark = data.to(
                self.device), gt_label.to(self.device), gt_bbox.to(
                    self.device).float(), gt_landmark.to(self.device).float()

            cls_pred, box_offset_pred, landmark_offset_pred = self.model(data)

            # compute the loss; the box term is down-weighted by 0.5
            cls_loss = self.lossfn.cls_loss(gt_label, cls_pred)
            box_offset_loss = self.lossfn.box_loss(gt_label, gt_bbox,
                                                   box_offset_pred)
            landmark_loss = self.lossfn.landmark_loss(gt_label, gt_landmark,
                                                      landmark_offset_pred)
            total_loss = cls_loss + box_offset_loss * 0.5 + landmark_loss
            accuracy = self.compute_accuracy(cls_pred, gt_label)

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            # .item() detaches the scalars so the computation graph is freed
            cls_loss_.update(cls_loss.item(), data.size(0))
            box_offset_loss_.update(box_offset_loss.item(), data.size(0))
            landmark_loss_.update(landmark_loss.item(), data.size(0))
            total_loss_.update(total_loss.item(), data.size(0))
            accuracy_.update(accuracy.item(), data.size(0))
            print(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.6f}'
                .format(epoch, batch_idx * len(data),
                        len(self.train_loader.dataset),
                        100. * batch_idx / len(self.train_loader),
                        total_loss.item(), accuracy.item()))

        self.scalar_info['cls_loss'] = cls_loss_.avg
        self.scalar_info['box_offset_loss'] = box_offset_loss_.avg
        self.scalar_info['landmark_loss'] = landmark_loss_.avg
        self.scalar_info['total_loss'] = total_loss_.avg
        self.scalar_info['accuracy'] = accuracy_.avg
        self.scalar_info['lr'] = self.scheduler.get_lr()[0]

        if self.logger is not None:
            for tag, value in list(self.scalar_info.items()):
                self.logger.scalar_summary(tag, value, self.run_count)
            self.scalar_info = {}
        self.run_count += 1
        print("|===>Loss: {:.4f}".format(total_loss_.avg))
        return (cls_loss_.avg, box_offset_loss_.avg, landmark_loss_.avg,
                total_loss_.avg, accuracy_.avg)
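# compute_accuracy above scores only samples whose ground-truth label is
# >= 0 and thresholds the predicted probability at 0.6. A self-contained
# NumPy restatement of the same rule, for reference:
def _thresholded_accuracy_sketch(prob, gt, thresh=0.6):
    """prob, gt: 1-D numpy arrays; gt < 0 marks samples to skip."""
    valid = gt >= 0
    pred = (prob[valid] >= thresh).astype(np.float32)
    return float((pred == gt[valid].astype(np.float32)).mean())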
class FastCNNTrainer(object):

    def __init__(self, lr, train_loader, valid_loader, model, optimizer,
                 scheduler, logger, device):
        self.lr = lr
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = LossFn(self.device)
        self.logger = logger
        self.run_count = 0
        self.scalar_info = {}

    def compute_accuracy(self, prob_cls, gt_cls):
        pred_cls = torch.max(prob_cls, 1)[1].squeeze()
        accuracy = float((pred_cls == gt_cls).sum()) / float(gt_cls.size(0))
        return accuracy

    def update_lr(self, epoch):
        """
        update learning rate of optimizers
        :param epoch: current training epoch
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def train(self, epoch):
        cls_loss_ = AverageMeter()
        accuracy_ = AverageMeter()
        accuracy_valid_ = AverageMeter()

        # feed the training set through the model
        self.scheduler.step()
        self.model.train()
        for batch_idx, (data, gt_label) in enumerate(self.train_loader):
            data, gt_label = data.to(self.device), gt_label.to(self.device)
            cls_pred, feature = self.model(data)

            # compute the loss
            cls_loss = self.lossfn.cls_loss(gt_label, cls_pred)
            accuracy = self.compute_accuracy(cls_pred, gt_label)

            self.optimizer.zero_grad()
            cls_loss.backward()
            self.optimizer.step()

            cls_loss_.update(cls_loss.item(), data.size(0))
            accuracy_.update(accuracy, data.size(0))
            if batch_idx % 20 == 10:
                print(batch_idx)
                print(cls_loss.item())

        # feed the validation set through the model
        with torch.no_grad():
            self.model.eval()
            for batch_idx, (data, gt_label) in enumerate(self.valid_loader):
                data, gt_label = data.to(self.device), gt_label.to(self.device)
                cls_pred, feature = self.model(data)
                accuracy_valid = self.compute_accuracy(cls_pred, gt_label)
                accuracy_valid_.update(accuracy_valid, data.size(0))

        # log the statistics
        self.scalar_info['cls_loss'] = cls_loss_.avg
        self.scalar_info['accuracy'] = accuracy_.avg
        self.scalar_info['lr'] = self.scheduler.get_lr()[0]
        # if self.logger is not None:
        #     for tag, value in list(self.scalar_info.items()):
        #         self.logger.scalar_summary(tag, value, self.run_count)
        #     self.scalar_info = {}
        # self.run_count += 1

        print(
            "\r\nEpoch: {}|===>Train Loss: {:.8f} Train Accuracy: {:.6f} valid Accuracy: {:.6f}\r\n"
            .format(epoch, cls_loss_.avg, accuracy_.avg, accuracy_valid_.avg))
        return cls_loss_.avg, accuracy_.avg, accuracy_valid_.avg

    def calculate_mean_feature(self):
        # average the 384*3*3 feature vector per class over the validation set
        with torch.no_grad():
            self.model.eval()
            featureMean = torch.zeros((8, 384 * 3 * 3))
            labelNumber = torch.zeros(8)
            for batch_idx, (data, gt_label) in enumerate(self.valid_loader):
                data, gt_label = data.to(self.device), gt_label.to(self.device)
                cls_pred, feature = self.model(data)
                for s in range(data.size(0)):
                    label = gt_label[s]
                    # move to CPU to match the accumulator's device
                    featureMean[label] += feature[s].cpu()
                    labelNumber[label] += 1
            for l in range(featureMean.size(0)):
                featureMean[l] = featureMean[l] / labelNumber[l]
            return featureMean
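# calculate_mean_feature above accumulates a per-class mean of the
# 384*3*3-dimensional feature with an explicit per-sample loop. A
# vectorised sketch of the same reduction (the class count of 8 is taken
# from the code above; classes absent from the input yield NaN rows):
def _per_class_mean_sketch(features, labels, num_classes=8):
    """features: 2-D float CPU tensor (N, D); labels: 1-D LongTensor (N,)."""
    sums = torch.zeros(num_classes, features.size(1))
    counts = torch.zeros(num_classes)
    sums.index_add_(0, labels, features)  # sum features per class
    counts.index_add_(0, labels, torch.ones(labels.size(0)))
    return sums / counts.unsqueeze(1)  # count-weighted mean per class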
# ONet trainer variant with additional attribute heads (color / layer /
# type). NOTE: it carries the same name as the ONetTrainer above, so when
# both are kept in one module this later definition shadows the earlier one.
class ONetTrainer(object):

    def __init__(self, lr, train_loader, model, optimizer, scheduler, logger,
                 device):
        self.lr = lr
        self.train_loader = train_loader
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = LossFn(self.device)
        self.logger = logger
        self.run_count = 0
        self.scalar_info = {}

    def compute_accuracy(self, prob_cls, gt_cls, prob_attr, gt_attr):
        # classification accuracy: only samples with label >= 0 count
        prob_cls = torch.squeeze(prob_cls)
        mask = torch.ge(gt_cls, 0)
        # keep the valid elements
        valid_gt_cls = torch.masked_select(gt_cls, mask)
        valid_prob_cls = torch.masked_select(prob_cls, mask)
        size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0])
        prob_ones = torch.ge(valid_prob_cls, 0.6).float()
        right_ones = torch.eq(prob_ones, valid_gt_cls.float()).float()
        accuracy_cls = torch.sum(right_ones) / float(size)

        # attribute accuracy: only samples labelled -2 carry attributes
        prob_attr = torch.squeeze(prob_attr)
        mask_attr = torch.eq(gt_cls, -2)
        chose_index = torch.nonzero(mask_attr.data)
        chose_index = torch.squeeze(chose_index)
        valid_gt_attr = gt_attr[chose_index, :]
        valid_prob_attr = prob_attr[chose_index, :]
        size_attr = min(valid_gt_attr.size()[0], valid_prob_attr.size()[0])
        valid_gt_color = valid_gt_attr[:, 0]
        valid_gt_layer = valid_gt_attr[:, 1]
        valid_gt_type = valid_gt_attr[:, 2]
        # the attribute logits pack three classifiers: columns 0:5 color,
        # 5:7 layer, 7: type; torch.max returns (values, indices)
        valid_prob_color = torch.max(valid_prob_attr[:, :5], 1)
        valid_prob_layer = torch.max(valid_prob_attr[:, 5:7], 1)
        valid_prob_type = torch.max(valid_prob_attr[:, 7:], 1)
        color_right_ones = torch.eq(valid_prob_color[1],
                                    valid_gt_color).float()
        layer_right_ones = torch.eq(valid_prob_layer[1],
                                    valid_gt_layer).float()
        type_right_ones = torch.eq(valid_prob_type[1], valid_gt_type).float()
        accuracy_color = torch.sum(color_right_ones) / float(size_attr)
        accuracy_layer = torch.sum(layer_right_ones) / float(size_attr)
        accuracy_type = torch.sum(type_right_ones) / float(size_attr)
        return accuracy_cls, accuracy_color, accuracy_layer, accuracy_type

    def update_lr(self, epoch):
        """
        update learning rate of optimizers
        :param epoch: current training epoch
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def train(self, epoch):
        cls_loss_ = AverageMeter()
        box_offset_loss_ = AverageMeter()
        landmark_loss_ = AverageMeter()
        total_loss_ = AverageMeter()
        accuracy_cls_ = AverageMeter()
        accuracy_color_ = AverageMeter()
        accuracy_layer_ = AverageMeter()
        accuracy_type_ = AverageMeter()
        self.scheduler.step()
        self.model.train()

        for batch_idx, (data, target) in enumerate(self.train_loader):
            gt_label = target['label']
            gt_bbox = target['bbox_target']
            gt_landmark = target['landmark_target']
            gt_attr = target['attribute']
            data, gt_label, gt_bbox, gt_landmark, gt_attr = data.to(
                self.device), gt_label.to(self.device), gt_bbox.to(
                    self.device).float(), gt_landmark.to(
                        self.device).float(), gt_attr.to(self.device).long()

            cls_pred, box_offset_pred, landmark_offset_pred, attr_pred = \
                self.model(data)
            # debugging: inspect the first predictions
            # print(cls_pred[0:100])
            # print(box_offset_pred[0:100, :])
            # print(landmark_offset_pred[0:100, :])
            # print(attr_pred[0:100, :])

            # compute the loss; box and attribute terms are weighted by 0.5
            cls_loss = self.lossfn.cls_loss(gt_label, cls_pred)
            box_offset_loss = self.lossfn.box_loss(gt_label, gt_bbox,
                                                   box_offset_pred)
            landmark_loss = self.lossfn.landmark_loss(gt_label, gt_landmark,
                                                      landmark_offset_pred)
            color_loss, layer_loss, type_loss = self.lossfn.attr_loss(
                gt_label, gt_attr, attr_pred)
            total_loss = (cls_loss + box_offset_loss * 0.5 + landmark_loss +
                          color_loss * 0.5 + layer_loss * 0.5 +
                          type_loss * 0.5)
            accuracy_cls, accuracy_color, accuracy_layer, accuracy_type = \
                self.compute_accuracy(cls_pred, gt_label, attr_pred, gt_attr)

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            # .item() detaches the scalars so the computation graph is freed
            cls_loss_.update(cls_loss.item(), data.size(0))
            box_offset_loss_.update(box_offset_loss.item(), data.size(0))
            landmark_loss_.update(landmark_loss.item(), data.size(0))
            total_loss_.update(total_loss.item(), data.size(0))
            accuracy_cls_.update(accuracy_cls.item(), data.size(0))
            accuracy_color_.update(accuracy_color.item(), data.size(0))
            accuracy_layer_.update(accuracy_layer.item(), data.size(0))
            accuracy_type_.update(accuracy_type.item(), data.size(0))
            print(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, cls_loss: {:.6f}, box_loss: {:.6f}, landmark_loss: {:.6f}, color_loss: {:.6f}, layer_loss: {:.6f}, type_loss: {:.6f}'
                .format(epoch, batch_idx * len(data),
                        len(self.train_loader.dataset),
                        100. * batch_idx / len(self.train_loader),
                        total_loss.item(), cls_loss.item(),
                        box_offset_loss.item(), landmark_loss.item(),
                        color_loss.item(), layer_loss.item(),
                        type_loss.item()))
            print(
                'Accuracy_cls: {:.6f}, Accuracy_color: {:.6f}, Accuracy_layer: {:.6f}, Accuracy_type: {:.6f}'
                .format(accuracy_cls.item(), accuracy_color.item(),
                        accuracy_layer.item(), accuracy_type.item()))

        self.scalar_info['cls_loss'] = cls_loss_.avg
        self.scalar_info['box_offset_loss'] = box_offset_loss_.avg
        self.scalar_info['landmark_loss'] = landmark_loss_.avg
        self.scalar_info['total_loss'] = total_loss_.avg
        self.scalar_info['accuracy_cls'] = accuracy_cls_.avg
        self.scalar_info['accuracy_color'] = accuracy_color_.avg
        self.scalar_info['accuracy_layer'] = accuracy_layer_.avg
        self.scalar_info['accuracy_type'] = accuracy_type_.avg
        self.scalar_info['lr'] = self.scheduler.get_lr()[0]

        if self.logger is not None:
            for tag, value in list(self.scalar_info.items()):
                self.logger.scalar_summary(tag, value, self.run_count)
            self.scalar_info = {}
        self.run_count += 1
        print("|===>Loss: {:.4f}".format(total_loss_.avg))
        return (cls_loss_.avg, box_offset_loss_.avg, landmark_loss_.avg,
                total_loss_.avg, accuracy_cls_.avg, accuracy_color_.avg,
                accuracy_layer_.avg, accuracy_type_.avg)
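# The attribute head above packs three classifiers into one logit vector:
# columns 0:5 are color, 5:7 layer and 7: onwards type, as the slicing in
# compute_accuracy shows. A minimal sketch of decoding that layout into
# per-task class indices:
def _decode_attr_logits_sketch(attr_logits):
    color = attr_logits[:, :5].argmax(dim=1)      # 5 color classes
    layer = attr_logits[:, 5:7].argmax(dim=1)     # 2 layer classes
    attr_type = attr_logits[:, 7:].argmax(dim=1)  # remaining columns: type
    return color, layer, attr_type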
class LeNet_5Trainer(object):

    def __init__(self, lr, train_loader, valid_x, valid_y, model, optimizer,
                 scheduler, logger, device):
        self.lr = lr
        self.train_loader = train_loader
        self.valid_x = valid_x
        self.valid_y = valid_y
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = LossFn(self.device)
        self.logger = logger
        self.run_count = 0
        self.scalar_info = {}

    def compute_accuracy(self, prob_cls, gt_cls):
        pred_cls = torch.max(prob_cls, 1)[1].squeeze()
        accuracy = float((pred_cls == gt_cls).sum()) / float(gt_cls.size(0))
        return accuracy

    def train(self, epoch):
        cls_loss_ = AverageMeter()
        accuracy_ = AverageMeter()
        self.model.train()
        for batch_idx, (data, gt_label) in enumerate(self.train_loader):
            data, gt_label = data.to(self.device), gt_label.to(self.device)
            cls_pred = self.model(data)

            # compute the loss
            cls_loss = self.lossfn.cls_loss(gt_label, cls_pred)
            accuracy = self.compute_accuracy(cls_pred, gt_label)

            self.optimizer.zero_grad()
            cls_loss.backward()
            self.optimizer.step()

            # .item() detaches the scalar so the computation graph is freed
            cls_loss_.update(cls_loss.item(), data.size(0))
            accuracy_.update(accuracy, data.size(0))
            if batch_idx % 50 == 0:
                print(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}\tTrain Accuracy: {:.6f}'
                    .format(epoch, batch_idx * len(data),
                            len(self.train_loader.dataset),
                            100. * batch_idx / len(self.train_loader),
                            cls_loss.item(), accuracy))

        self.scalar_info['cls_loss'] = cls_loss_.avg
        self.scalar_info['accuracy'] = accuracy_.avg
        self.scalar_info['lr'] = self.lr
        # if self.logger is not None:
        #     for tag, value in list(self.scalar_info.items()):
        #         self.logger.scalar_summary(tag, value, self.run_count)
        #     self.scalar_info = {}
        # self.run_count += 1

        print("|===>Loss: {:.4f} Train Accuracy: {:.6f} ".format(
            cls_loss_.avg, accuracy_.avg))
        return cls_loss_.avg, accuracy_.avg
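# A hedged usage sketch for LeNet_5Trainer. Unlike the other trainers,
# train() never steps self.scheduler, so the caller would step it between
# epochs if a decaying rate is wanted; every name passed in below is a
# placeholder supplied by the caller.
def _run_lenet_trainer_sketch(trainer, num_epochs):
    history = []
    for epoch in range(1, num_epochs + 1):
        loss, acc = trainer.train(epoch)
        trainer.scheduler.step()  # assumed external stepping, if desired
        history.append((epoch, loss, acc))
    return history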