def main(args, dst_folder):
    # best_acc_val records the best top-1 accuracy on the validation set
    best_acc_val = 0.0

    # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    if args.cuda_dev == 1:
        torch.cuda.set_device(1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.deterministic = True  # fix the GPU to deterministic mode
    torch.manual_seed(args.seed)  # CPU seed
    if device.type == "cuda":  # compare the device type, not the device object
        torch.cuda.manual_seed_all(args.seed)  # GPU seed
    random.seed(args.seed)  # python seed for image transformation
    np.random.seed(args.seed)

    if args.dataset == 'svhn':
        mean = [x / 255 for x in [127.5, 127.5, 127.5]]
        std = [x / 255 for x in [127.5, 127.5, 127.5]]
    elif args.dataset == 'cifar100':
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2675, 0.2565, 0.2761]

    if args.DA == "standard":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.RandomCrop(32),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    elif args.DA == "jitter":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            transforms.RandomCrop(32),
            # SVHNPolicy(),
            # AutoAugment(),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # Cutout(n_holes=1, length=20),
            transforms.Normalize(mean, std),
        ])
    else:
        print("Wrong value for --DA argument.")

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # data loader
    train_loader, test_loader, train_noisy_indexes = data_config(args, transform_train, transform_test, dst_folder)

    if args.network == "MT_Net":
        print("Loading MT_Net...")
        model = MT_Net(num_classes=args.num_classes, dropRatio=args.dropout).to(device)
    elif args.network == "WRN28_2_wn":
        print("Loading WRN28_2...")
        model = WRN28_2_wn(num_classes=args.num_classes, dropout=args.dropout).to(device)
    elif args.network == "PreactResNet18_WNdrop":
        print("Loading preActResNet18_WNdrop...")
        model = PreactResNet18_WNdrop(drop_val=args.dropout, num_classes=args.num_classes).to(device)

    print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    milestones = args.M

    if args.swa == 'True':
        # to install torchcontrib:
        #   pip3 install torchcontrib
        # or from source:
        #   git clone https://github.com/pytorch/contrib.git
        #   cd contrib
        #   sudo python3 setup.py install
        from torchcontrib.optim import SWA
        # base_optimizer = RAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=1e-4)
        base_optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=1e-4)
        optimizer = SWA(base_optimizer, swa_lr=args.swa_lr)
    else:
        # optimizer = RAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=1e-4)
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=1e-4)

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    loss_train_epoch = []
    loss_val_epoch = []
    acc_train_per_epoch = []
    acc_val_per_epoch = []
    new_labels = []

    exp_path = os.path.join('./', 'noise_models_{0}'.format(args.experiment_name), str(args.labeled_samples))
    res_path = os.path.join('./', 'metrics_{0}'.format(args.experiment_name), str(args.labeled_samples))

    if not os.path.isdir(res_path):
        os.makedirs(res_path)
    if not os.path.isdir(exp_path):
        os.makedirs(exp_path)

    cont = 0
    load = False
    save = True

    if args.initial_epoch != 0:
        initial_epoch = args.initial_epoch
        load = True
        save = False

    if args.dataset_type == 'sym_noise_warmUp':
        load = False
        save = True

    if load:
        if args.loss_term == 'Reg_ep':
            train_type = 'C'
        if args.loss_term == 'MixUp_ep':
            train_type = 'M'
        if args.dropout > 0.0:
            train_type = train_type + 'drop' + str(int(10 * args.dropout))
        if args.beta == 0.0:
            train_type = train_type + 'noReg'
        path = './checkpoints/warmUp_{6}_{5}_{0}_{1}_{2}_{3}_S{4}.hdf5'.format(
            initial_epoch, args.dataset, args.labeled_samples,
            args.network, args.seed, args.Mixup_Alpha, train_type)

        checkpoint = torch.load(path)
        print("Load model in epoch " + str(checkpoint['epoch']))
        print("Path loaded: ", path)
        model.load_state_dict(checkpoint['state_dict'])

        print("Relabeling the unlabeled samples...")
        model.eval()
        initial_rand_relab = args.label_noise
        results = np.zeros((len(train_loader.dataset), 10), dtype=np.float32)  # 10 output classes
        for images, images_pslab, labels, soft_labels, index in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            soft_labels = soft_labels.to(device)

            outputs = model(images)
            prob, loss = loss_soft_reg_ep(outputs, labels, soft_labels, device, args)
            results[index.detach().numpy().tolist()] = prob.cpu().detach().numpy().tolist()

        train_loader.dataset.update_labels_randRelab(results, train_noisy_indexes, initial_rand_relab)

    print("Start training...")
    for epoch in range(1, args.epoch + 1):
        st = time.time()
        scheduler.step()
        # train for one epoch
        print(args.experiment_name, args.labeled_samples)
        loss_per_epoch, top_5_train_ac, top1_train_acc_original_labels, \
            top1_train_ac, train_time = train_CrossEntropy_partialRelab(
                args, model, device, train_loader, optimizer, epoch, train_noisy_indexes)

        loss_train_epoch += [loss_per_epoch]

        # test
        if args.validation_exp == "True":
            loss_per_epoch, acc_val_per_epoch_i = validating(args, model, device, test_loader)
        else:
            loss_per_epoch, acc_val_per_epoch_i = testing(args, model, device, test_loader)

        loss_val_epoch += loss_per_epoch
        acc_train_per_epoch += [top1_train_ac]
        acc_val_per_epoch += acc_val_per_epoch_i

        ####################################################################################################
        #############################              SAVING MODELS              #############################
        ####################################################################################################
        if not os.path.exists('./checkpoints'):
            os.mkdir('./checkpoints')

        if epoch == 1:
            best_acc_val = acc_val_per_epoch_i[-1]
            snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestAccVal_%.5f' % (
                epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
        else:
            if acc_val_per_epoch_i[-1] > best_acc_val:
                best_acc_val = acc_val_per_epoch_i[-1]
                if cont > 0:
                    try:
                        os.remove(os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
                        os.remove(os.path.join(exp_path, snapBest + '.pth'))
                    except OSError:
                        pass
                snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestAccVal_%.5f' % (
                    epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
                torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
                torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))

        cont += 1

        if epoch == args.epoch:
            snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestValLoss_%.5f' % (
                epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

        #### Save models for ensembles:
        if (epoch >= 150) and (epoch % 2 == 0) and (args.save_checkpoint == "True"):
            print("Saving model...")
            out_path = './checkpoints/ENS_{0}_{1}'.format(args.experiment_name, args.labeled_samples)
            if not os.path.exists(out_path):
                os.makedirs(out_path)
            torch.save(model.state_dict(), out_path + "/epoch_{0}.pth".format(epoch))

        ### Saving model to load it again
        # cond = epoch % 1 == 0
        if args.dataset_type == 'sym_noise_warmUp':
            if args.loss_term == 'Reg_ep':
                train_type = 'C'
            if args.loss_term == 'MixUp_ep':
                train_type = 'M'
            if args.dropout > 0.0:
                train_type = train_type + 'drop' + str(int(10 * args.dropout))
            if args.beta == 0.0:
                train_type = train_type + 'noReg'
            cond = (epoch == args.epoch)
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            save = True
        else:
            cond = (epoch == args.epoch)
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            save = True

        if cond and save:
            print("Saving models...")
            path = './checkpoints/{0}_{1}_{2}_{3}_{4}_S{5}.hdf5'.format(
                name, epoch, args.dataset, args.labeled_samples, args.network, args.seed)
            save_checkpoint({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss_train_epoch': np.asarray(loss_train_epoch),
                'loss_val_epoch': np.asarray(loss_val_epoch),
                'acc_train_per_epoch': np.asarray(acc_train_per_epoch),
                'acc_val_per_epoch': np.asarray(acc_val_per_epoch),
                'labels': np.asarray(train_loader.dataset.soft_labels),
            }, filename=path)

        ####################################################################################################
        ############################              SAVING METRICS              #############################
        ####################################################################################################

        # Save losses:
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_train.npy', np.asarray(loss_train_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_val.npy', np.asarray(loss_val_epoch))

        # Save accuracies:
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_train.npy', np.asarray(acc_train_per_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_val.npy', np.asarray(acc_val_per_epoch))

        # Save the new labels:
        new_labels.append(train_loader.dataset.labels)
        np.save(res_path + '/' + str(args.labeled_samples) + '_new_labels.npy', np.asarray(new_labels))

        # logging.info('Epoch: [{}|{}], train_loss: {:.3f}, top1_train_ac: {:.3f}, top1_val_ac: {:.3f}, train_time: {:.3f}'.format(
        #     epoch, args.epoch, loss_per_epoch[-1], top1_train_ac, acc_val_per_epoch_i[-1], time.time() - st))

    # applying swa
    if args.swa == 'True':
        optimizer.swap_swa_sgd()
        optimizer.bn_update(train_loader, model, device)
        if args.validation_exp == "True":
            loss_swa, acc_val_swa = validating(args, model, device, test_loader)
        else:
            loss_swa, acc_val_swa = testing(args, model, device, test_loader)

        snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestValLoss_%.5f_swaAcc_%.5f' % (
            epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val, acc_val_swa[0])
        torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

    # save_fig(dst_folder)
    print('Best acc: %f' % best_acc_val)
    record_result(dst_folder, best_acc_val)
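# The script above relies on torchcontrib's SWA in "auto" mode. A minimal,
# self-contained sketch of that pattern (the model, data, and the
# swa_start/swa_freq/swa_lr values below are illustrative assumptions, not
# taken from the script):
import torch
import torch.nn as nn
import torch.optim as optim
from torchcontrib.optim import SWA

_model = nn.Linear(10, 2)
_base_opt = optim.SGD(_model.parameters(), lr=0.1, momentum=0.9)
# after step swa_start, the running weight average advances every swa_freq steps
_opt = SWA(_base_opt, swa_start=100, swa_freq=10, swa_lr=0.05)

_loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(200)]
for _x, _y in _loader:
    _opt.zero_grad()
    nn.functional.cross_entropy(_model(_x), _y).backward()
    _opt.step()

_opt.swap_swa_sgd()              # copy the averaged weights into the model
_opt.bn_update(_loader, _model)  # recompute BatchNorm statistics for the averaged weights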
    if scheduler is not None:
        scheduler.step(total_pesq / num_test_data)

    writer.add_scalar('Loss/valid', total_loss / num_test_data, epoch)
    # writer.add_scalar('PESQ/valid', total_pesq / num_test_data, epoch)

    # checkpointing
    curr_loss = total_loss / num_test_data
    if curr_loss < best_loss:
        best_loss = curr_loss
        save_path = os.path.join(ckpt_path, 'model_best.ckpt')
        print(f'Saving checkpoint to {save_path}')
        torch.save({
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': total_loss / num_test_data,
        }, save_path)

    # curr_pesq = total_pesq / num_test_data
    # if curr_pesq > best_pesq:
    #     best_pesq = curr_pesq
    #     save_path = os.path.join(ckpt_path, 'model_best.ckpt')
    #     print(f'Saving checkpoint to {save_path}')
    #     torch.save({
    #         'epoch': epoch,
    #         'model_state_dict': net.state_dict(),
    #         'optimizer_state_dict': optimizer.state_dict(),
    #         'loss': total_loss / num_test_data,
    #         'pesq': total_pesq / num_test_data
    #     }, save_path)
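# The fragment above saves a "best loss" checkpoint; a matching resume helper
# might look like this (a sketch: load_best is a hypothetical name, reusing the
# ckpt_path/net naming of the fragment):
import os
import torch

def load_best(ckpt_path, net, optimizer):
    """Restore the checkpoint written above, if present; return (start_epoch, best_loss)."""
    save_path = os.path.join(ckpt_path, 'model_best.ckpt')
    if not os.path.isfile(save_path):
        return 0, float('inf')
    ckpt = torch.load(save_path, map_location='cpu')
    net.load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    return ckpt['epoch'] + 1, ckpt['loss']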
class Trainer():
    def __init__(self, config_path):
        self.image_config, self.model_config, self.run_config = LoadConfig(
            config_path=config_path).train_config()
        self.device = torch.device(
            'cuda:%d' % self.run_config['device_ids'][0]
            if torch.cuda.is_available() else 'cpu')  # is_available must be called
        self.model = getModel(self.model_config)
        os.makedirs(self.run_config['model_save_path'], exist_ok=True)
        self.run_config['num_workers'] = self.run_config['num_workers'] * len(
            self.run_config['device_ids'])
        self.train_set = Data(root=self.image_config['image_path'],
                              phase='train',
                              data_name=self.image_config['data_name'],
                              img_mode=self.image_config['image_mode'],
                              n_classes=self.model_config['num_classes'],
                              size=self.image_config['image_size'],
                              scale=self.image_config['image_scale'])
        self.valid_set = Data(root=self.image_config['image_path'],
                              phase='valid',
                              data_name=self.image_config['data_name'],
                              img_mode=self.image_config['image_mode'],
                              n_classes=self.model_config['num_classes'],
                              size=self.image_config['image_size'],
                              scale=self.image_config['image_scale'])
        self.className = self.valid_set.className
        self.train_loader = DataLoader(self.train_set,
                                       batch_size=self.run_config['batch_size'],
                                       shuffle=True,
                                       num_workers=self.run_config['num_workers'],
                                       pin_memory=True,
                                       drop_last=False)
        self.valid_loader = DataLoader(self.valid_set,
                                       batch_size=self.run_config['batch_size'],
                                       shuffle=True,
                                       num_workers=self.run_config['num_workers'],
                                       pin_memory=True,
                                       drop_last=False)
        train_params = self.model.parameters()
        self.optimizer = RAdam(train_params,
                               lr=eval(self.run_config['lr']),
                               weight_decay=eval(self.run_config['weight_decay']))
        if self.run_config['swa']:
            self.optimizer = SWA(self.optimizer, swa_start=10, swa_freq=5, swa_lr=0.005)
        # set up the learning-rate scheduling strategy
        self.lr_scheduler = utils.adjustLR.AdjustLr(self.optimizer)
        if self.run_config['use_weight_balance']:
            weight = utils.weight_balance.getWeight(self.run_config['weights_file'])
        else:
            weight = None
        self.Criterion = SegmentationLosses(weight=weight, cuda=True,
                                            device=self.device, batch_average=False)
        self.metric = utils.metrics.MetricMeter(self.model_config['num_classes'])

    @logger.catch  # record errors in the log
    def __call__(self):
        # set up logging
        self.global_name = self.model_config['model_name']
        logger.add(os.path.join(self.image_config['image_path'], 'log',
                                'log_' + self.global_name + '/train_{time}.log'),
                   format="{time} {level} {message}", level="INFO", encoding='utf-8')
        self.writer = SummaryWriter(logdir=os.path.join(
            self.image_config['image_path'], 'run', 'runs_' + self.global_name))
        logger.info("image_config: {} \n model_config: {} \n run_config: {}",
                    self.image_config, self.model_config, self.run_config)
        # use data parallelism when more than one GPU is available
        if len(self.run_config['device_ids']) > 1:
            self.model = nn.DataParallel(self.model,
                                         device_ids=self.run_config['device_ids'])
        self.model.to(device=self.device)
        cnt = 0
        # load a pretrained model if one is given
        if self.run_config['pretrain'] != '':
            logger.info("loading pretrain %s" % self.run_config['pretrain'])
            try:
                self.load_checkpoint(use_optimizer=True, use_epoch=True, use_miou=True)
            except Exception:
                print('loading model with changed keys')
                self.load_checkpoint_with_changed(use_optimizer=False,
                                                  use_epoch=False,
                                                  use_miou=False)
        logger.info("start training")
        for epoch in range(self.run_config['start_epoch'], self.run_config['epoch']):
            lr = self.optimizer.param_groups[0]['lr']
            print('epoch=%d, lr=%.8f' % (epoch, lr))
            self.train_epoch(epoch, lr)
            valid_miou = self.valid_epoch(epoch)
            # choose which learning-rate schedule to apply
            self.lr_scheduler.LambdaLR_(milestone=5, gamma=0.92).step(epoch=epoch)
            self.save_checkpoint(epoch, valid_miou, 'last_' + self.global_name)
            if valid_miou > self.run_config['best_miou']:
                cnt = 0
                self.save_checkpoint(epoch, valid_miou, 'best_' + self.global_name)
                logger.info("############# %d saved ##############" % epoch)
                self.run_config['best_miou'] = valid_miou
            else:
                cnt += 1
                if cnt == self.run_config['early_stop']:
                    logger.info("early stop")
                    break
        self.writer.close()

    def train_epoch(self, epoch, lr):
        self.metric.reset()
        train_loss = 0.0
        train_miou = 0.0
        tbar = tqdm(self.train_loader)
        self.model.train()
        for i, (image, mask, edge) in enumerate(tbar):
            tbar.set_description('train_miou:%.6f' % train_miou)
            tbar.set_postfix({"train_loss": train_loss})
            image = image.to(self.device)
            mask = mask.to(self.device)
            edge = edge.to(self.device)
            self.optimizer.zero_grad()
            out = self.model(image)
            if isinstance(out, tuple):
                aux_out, final_out = out[0], out[1]
            else:
                aux_out, final_out = None, out
            if self.model_config['model_name'] == 'ocrnet':
                aux_loss = self.Criterion.build_loss(mode='rmi')(aux_out, mask)
                cls_loss = self.Criterion.build_loss(mode='ce')(final_out, mask)
                loss = 0.4 * aux_loss + cls_loss
                loss = loss.mean()
            elif self.model_config['model_name'] == 'hrnet_duc':
                loss_body = self.Criterion.build_loss(
                    mode=self.run_config['loss_type'])(final_out, mask)
                loss_edge = self.Criterion.build_loss(mode='dice')(aux_out.squeeze(), edge)
                loss = loss_body + loss_edge
                loss = loss.mean()
            else:
                loss = self.Criterion.build_loss(
                    mode=self.run_config['loss_type'])(final_out, mask)
            loss.backward()
            self.optimizer.step()
            if self.run_config['swa']:
                # note: this swaps weights on every step; torchcontrib's
                # swap_swa_sgd is normally called once, after training
                self.optimizer.swap_swa_sgd()
            with torch.no_grad():
                train_loss = ((train_loss * i) + loss.item()) / (i + 1)
                _, pred = torch.max(final_out, dim=1)
                self.metric.add(pred.cpu().numpy(), mask.cpu().numpy())
        train_miou, train_ious = self.metric.miou()
        train_fwiou = self.metric.fw_iou()
        train_accu = self.metric.pixel_accuracy()
        train_fwaccu = self.metric.pixel_accuracy_class()
        logger.info(
            "Epoch:%2d\t lr:%.8f\t Train loss:%.4f\t Train FWiou:%.4f\t Train Miou:%.4f\t Train accu:%.4f\t "
            "Train fwaccu:%.4f" %
            (epoch, lr, train_loss, train_fwiou, train_miou, train_accu, train_fwaccu))
        cls = ""
        ious = list()
        ious_dict = OrderedDict()
        for i, c in enumerate(self.className):
            ious_dict[c] = train_ious[i]
            ious.append(ious_dict[c])
            cls += "%s:" % c + "%.4f "
        ious = tuple(ious)
        logger.info(cls % ious)
        # tensorboard
        self.writer.add_scalar("lr", lr, epoch)
        self.writer.add_scalar("loss/train_loss", train_loss, epoch)
        self.writer.add_scalar("miou/train_miou", train_miou, epoch)
        self.writer.add_scalar("fwiou/train_fwiou", train_fwiou, epoch)
        self.writer.add_scalar("accuracy/train_accu", train_accu, epoch)
        self.writer.add_scalar("fwaccuracy/train_fwaccu", train_fwaccu, epoch)
        self.writer.add_scalars("ious/train_ious", ious_dict, epoch)

    def valid_epoch(self, epoch):
        self.metric.reset()
        valid_loss = 0.0
        valid_miou = 0.0
        tbar = tqdm(self.valid_loader)
        self.model.eval()
        with torch.no_grad():
            for i, (image, mask, edge) in enumerate(tbar):
                tbar.set_description('valid_miou:%.6f' % valid_miou)
                tbar.set_postfix({"valid_loss": valid_loss})
                image = image.to(self.device)
                mask = mask.to(self.device)
                edge = edge.to(self.device)
                out = self.model(image)
                if isinstance(out, tuple):
                    aux_out, final_out = out[0], out[1]
                else:
                    aux_out, final_out = None, out
                if self.model_config['model_name'] == 'ocrnet':
                    aux_loss = self.Criterion.build_loss(mode='rmi')(aux_out, mask)
                    cls_loss = self.Criterion.build_loss(mode='ce')(final_out, mask)
                    loss = 0.4 * aux_loss + cls_loss
                    loss = loss.mean()
                elif self.model_config['model_name'] == 'hrnet_duc':
                    loss_body = self.Criterion.build_loss(
                        mode=self.run_config['loss_type'])(final_out, mask)
                    loss_edge = self.Criterion.build_loss(mode='dice')(aux_out.squeeze(), edge)
                    loss = loss_body + loss_edge
                    # loss = loss.mean()
                else:
                    loss = self.Criterion.build_loss(mode='ce')(final_out, mask)
                valid_loss = ((valid_loss * i) + float(loss)) / (i + 1)
                _, pred = torch.max(final_out, dim=1)
                self.metric.add(pred.cpu().numpy(), mask.cpu().numpy())
            valid_miou, valid_ious = self.metric.miou()
            valid_fwiou = self.metric.fw_iou()
            valid_accu = self.metric.pixel_accuracy()
            valid_fwaccu = self.metric.pixel_accuracy_class()
            logger.info(
                "epoch:%d\t valid loss:%.4f\t valid fwiou:%.4f\t valid miou:%.4f valid accu:%.4f\t "
                "valid fwaccu:%.4f\t" %
                (epoch, valid_loss, valid_fwiou, valid_miou, valid_accu, valid_fwaccu))
            ious = list()
            cls = ""
            ious_dict = OrderedDict()
            for i, c in enumerate(self.className):
                ious_dict[c] = valid_ious[i]
                ious.append(ious_dict[c])
                cls += "%s:" % c + "%.4f "
            ious = tuple(ious)
            logger.info(cls % ious)
            self.writer.add_scalar("loss/valid_loss", valid_loss, epoch)
            self.writer.add_scalar("miou/valid_miou", valid_miou, epoch)
            self.writer.add_scalar("fwiou/valid_fwiou", valid_fwiou, epoch)
            self.writer.add_scalar("accuracy/valid_accu", valid_accu, epoch)
            self.writer.add_scalar("fwaccuracy/valid_fwaccu", valid_fwaccu, epoch)
            self.writer.add_scalars("ious/valid_ious", ious_dict, epoch)
        return valid_miou

    def save_checkpoint(self, epoch, best_miou, flag):
        meta = {
            'epoch': epoch,
            'model': self.model.state_dict(),
            'optim': self.optimizer.state_dict(),
            'bmiou': best_miou
        }
        try:
            torch.save(meta,
                       os.path.join(self.run_config['model_save_path'], '%s.pth' % flag),
                       _use_new_zipfile_serialization=False)
        except Exception:
            torch.save(meta,
                       os.path.join(self.run_config['model_save_path'], '%s.pth' % flag))

    def load_checkpoint(self, use_optimizer, use_epoch, use_miou):
        state_dict = torch.load(self.run_config['pretrain'], map_location=self.device)
        self.model.load_state_dict(state_dict['model'])
        if use_optimizer:
            self.optimizer.load_state_dict(state_dict['optim'])
        if use_epoch:
            self.run_config['start_epoch'] = state_dict['epoch'] + 1
        if use_miou:
            self.run_config['best_miou'] = state_dict['bmiou']

    def load_checkpoint_with_changed(self, use_optimizer, use_epoch, use_miou):
        state_dict = torch.load(self.run_config['pretrain'], map_location=self.device)
        pretrain_dict = state_dict['model']
        model_dict = self.model.state_dict()
        # keep only weights that exist in the current model and are not edge-specific
        pretrain_dict = {
            k: v for k, v in pretrain_dict.items()
            if k in model_dict and 'edge' not in k
        }
        model_dict.update(pretrain_dict)
        self.model.load_state_dict(model_dict)
        if use_optimizer:
            self.optimizer.load_state_dict(state_dict['optim'])
        if use_epoch:
            self.run_config['start_epoch'] = state_dict['epoch'] + 1
        if use_miou:
            self.run_config['best_miou'] = state_dict['bmiou']
class Optimizer:
    optimizer_cls = None
    optimizer = None
    parameters = None

    def __init__(self, gradient_clipping, swa_start=None, swa_freq=None,
                 swa_lr=None, **kwargs):
        self.gradient_clipping = gradient_clipping
        self.optimizer_kwargs = kwargs
        self.swa_start = swa_start
        self.swa_freq = swa_freq
        self.swa_lr = swa_lr

    def set_parameters(self, parameters):
        self.parameters = tuple(parameters)
        self.optimizer = self.optimizer_cls(self.parameters, **self.optimizer_kwargs)
        if self.swa_start is not None:
            from torchcontrib.optim import SWA
            assert self.swa_freq is not None, self.swa_freq
            assert self.swa_lr is not None, self.swa_lr
            self.optimizer = SWA(self.optimizer, swa_start=self.swa_start,
                                 swa_freq=self.swa_freq, swa_lr=self.swa_lr)

    def check_if_set(self):
        assert self.optimizer is not None, \
            'The optimizer is not initialized; call set_parameters before' \
            ' using any of the optimizer functions'

    def zero_grad(self):
        self.check_if_set()
        return self.optimizer.zero_grad()

    def step(self):
        self.check_if_set()
        return self.optimizer.step()

    def swap_swa_sgd(self):
        self.check_if_set()
        from torchcontrib.optim import SWA
        assert isinstance(self.optimizer, SWA), self.optimizer
        return self.optimizer.swap_swa_sgd()

    def clip_grad(self):
        self.check_if_set()
        # TODO: report clipped and unclipped norms
        # TODO: allow clip=None but still report grad_norm
        grad_clips = self.gradient_clipping
        return torch.nn.utils.clip_grad_norm_(self.parameters, grad_clips)

    def to(self, device):
        if device is None:
            return
        self.check_if_set()
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    def cpu(self):
        return self.to('cpu')

    def cuda(self, device=None):
        assert device is None or isinstance(device, int), device
        if device is None:
            device = torch.device('cuda')
        return self.to(device)

    def load_state_dict(self, state_dict):
        self.check_if_set()
        return self.optimizer.load_state_dict(state_dict)

    def state_dict(self):
        self.check_if_set()
        return self.optimizer.state_dict()
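# Hypothetical usage of the wrapper above: bind a concrete optimizer class in a
# subclass, then hand over the model parameters (SGDOptimizer and all values
# below are illustrative, not part of the original class):
import torch
import torch.nn as nn

class SGDOptimizer(Optimizer):
    optimizer_cls = torch.optim.SGD

_model = nn.Linear(10, 2)
_opt = SGDOptimizer(gradient_clipping=5.0, swa_start=1000, swa_freq=100,
                    swa_lr=0.01, lr=0.1, momentum=0.9)
_opt.set_parameters(_model.parameters())

_x, _y = torch.randn(8, 10), torch.randint(0, 2, (8,))
_opt.zero_grad()
nn.functional.cross_entropy(_model(_x), _y).backward()
_opt.clip_grad()  # clip gradients before the parameter update
_opt.step()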
# Call
print("Starting model training...")
n_epochs = setting_dict['epochs']
lr_patience = setting_dict['optimizer']['sheduler']['patience']
lr_factor = setting_dict['optimizer']['sheduler']['factor']

if weight_path is None:
    best_epoch = train(model, dataloaders, objective, optimizer, n_epochs,
                       Path_list[1], Path_list[2],
                       lr_patience=lr_patience, lr_factor=lr_factor,
                       dice=False, seperate_loss=False,
                       adabn=setting_dict["data"]["adabn_train"],
                       own_sheduler=(not setting_dict["optimizer"]["longshedule"]))
else:
    checkpoint = torch.load(weight_path)  # load once instead of repeatedly
    optimizer.load_state_dict(checkpoint["optimizer"])
    best_epoch = train(model, dataloaders, objective, optimizer,
                       n_epochs - checkpoint["epoch"],
                       Path_list[1], Path_list[2],
                       start_epoch=checkpoint["epoch"] + 1,
                       loss_dict=checkpoint["loss_dict"],
                       lr_patience=lr_patience, lr_factor=lr_factor,
                       dice=False, seperate_loss=False,
                       adabn=setting_dict["data"]["adabn_train"],
                       own_sheduler=(not setting_dict["optimizer"]["longshedule"]))

print("Model training finished!")

if optimizer_name == "SWA":
    print("Updating batch norm parameters for SWA")
    train_dataset.dataset.SWA = True
    SWA_loader = DataLoader(train_dataset, batch_size=batch_size,
                            shuffle=True, num_workers=cpu_count)
    optimizer.swap_swa_sgd()
    optimizer.bn_update(SWA_loader, model, device='cuda')

state = {
    'epoch': n_epochs,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'loss_dict': {}
}
torch.save(state, os.path.join(Path_list[2], 'weights_SWA.pt'))
progress["val_loss"].append(np.mean(val_loss)) progress["val_iou"].append(iou) progress["val_dice"].append(dice) progress["val_hausdorff"].append(hausdorff) progress["val_assd"].append(assd) dict2df(progress, args.output_dir + 'progress.csv') scheduler_step(optimizer, scheduler, iou, args) # --------------------------------------------------------------------------------------------------------------- # # --------------------------------------------------------------------------------------------------------------- # # --------------------------------------------------------------------------------------------------------------- # if args.apply_swa: torch.save(optimizer.state_dict(), args.output_dir + "/optimizer_" + args.model_name + "_before_swa_swap.pt") optimizer.swap_swa_sgd() # Set the weights of your model to their SWA averages optimizer.bn_update(train_loader, model, device='cuda') torch.save( model.state_dict(), args.output_dir + "/swa_checkpoint_last_bn_update_{}epochs_lr{}.pt".format(args.epochs, args.swa_lr) ) iou, dice, hausdorff, assd, val_loss, stats = val_step( val_loader, model, criterion, weights_criterion, multiclass_criterion, args.binary_threshold, generate_stats=True, generate_overlays=args.eval_overlays, save_path=os.path.join(args.output_dir, "swa_preds") ) print("[SWA] Val IOU: %s, Val Dice: %s" % (iou, dice))
def main(args):
    best_acc_val = 0.0  # best top-1 accuracy on the validation set

    #####################
    # Initializing seeds and preparing GPU
    if args.cuda_dev == 1:
        torch.cuda.set_device(1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.deterministic = True  # fix the GPU to deterministic mode
    torch.manual_seed(args.seed)  # CPU seed
    if device.type == "cuda":  # compare the device type, not the device object
        torch.cuda.manual_seed_all(args.seed)  # GPU seed
    random.seed(args.seed)  # python seed for image transformation
    np.random.seed(args.seed)
    #####################

    if args.dataset == 'cifar10':
        mean = [0.4914, 0.4822, 0.4465]
        std = [0.2023, 0.1994, 0.2010]
    elif args.dataset == 'cifar100':
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2675, 0.2565, 0.2761]
    elif args.dataset == 'miniImagenet':
        mean = [0.4728, 0.4487, 0.4031]
        std = [0.2744, 0.2663, 0.2806]

    if args.DA == "standard":
        transform_train = transforms.Compose([
            transforms.Pad(6, padding_mode='reflect'),
            transforms.RandomCrop(84),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    elif args.DA == "jitter":
        transform_train = transforms.Compose([
            transforms.Pad(6, padding_mode='reflect'),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            transforms.RandomCrop(84),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    else:
        print("Wrong value for --DA argument.")

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # data loader
    train_loader, test_loader, unlabeled_indexes = data_config(args, transform_train, transform_test)

    if args.network == "TE_Net":
        print("Loading TE_Net...")
        model = TE_Net(num_classes=args.num_classes).to(device)
    elif args.network == "MT_Net":
        print("Loading MT_Net...")
        model = MT_Net(num_classes=args.num_classes).to(device)
    elif args.network == "resnet18":
        print("Loading Resnet18...")
        model = resnet18(num_classes=args.num_classes).to(device)
    elif args.network == "resnet18_wndrop":
        print("Loading Resnet18 (weight norm + dropout)...")
        model = resnet18_wndrop(num_classes=args.num_classes).to(device)

    print('Total params: {:.2f} M'.format(sum(p.numel() for p in model.parameters()) / 1000000.0))

    milestones = args.M

    if args.swa == 'True':
        # to install torchcontrib:
        #   pip3 install torchcontrib
        # or from source:
        #   git clone https://github.com/pytorch/contrib.git
        #   cd contrib
        #   sudo python3 setup.py install
        from torchcontrib.optim import SWA
        base_optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
        optimizer = SWA(base_optimizer, swa_lr=args.swa_lr)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    loss_train_epoch = []
    loss_val_epoch = []
    acc_train_per_epoch = []
    acc_val_per_epoch = []
    new_labels = []

    exp_path = os.path.join('./', 'ssl_models_{0}'.format(args.experiment_name), str(args.labeled_samples))
    res_path = os.path.join('./', 'metrics_{0}'.format(args.experiment_name), str(args.labeled_samples))

    if not os.path.isdir(res_path):
        os.makedirs(res_path)
    if not os.path.isdir(exp_path):
        os.makedirs(exp_path)

    cont = 0
    load = False
    save = True

    if args.load_epoch != 0:
        load_epoch = args.load_epoch
        load = True
        save = False

    if args.dataset_type == 'ssl_warmUp':
        load = False
        save = True

    if load:
        if args.loss_term == 'Reg_ep':
            train_type = 'C'
        if args.loss_term == 'MixUp_ep':
            train_type = 'M'
        path = './checkpoints/warmUp_{0}_{1}_{2}_{3}_{4}_{5}_S{6}.hdf5'.format(
            train_type, args.Mixup_Alpha, load_epoch, args.dataset,
            args.labeled_samples, args.network, args.seed)

        checkpoint = torch.load(path)
        print("Load model in epoch " + str(checkpoint['epoch']))
        print("Path loaded: ", path)
        model.load_state_dict(checkpoint['state_dict'])

        print("Relabeling the unlabeled samples...")
        model.eval()
        results = np.zeros((len(train_loader.dataset), args.num_classes), dtype=np.float32)
        for images, images_pslab, labels, soft_labels, index in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            soft_labels = soft_labels.to(device)

            outputs = model(images)
            prob, loss = loss_soft_reg_ep(outputs, labels, soft_labels, device, args)
            results[index.detach().numpy().tolist()] = prob.cpu().detach().numpy().tolist()

        train_loader.dataset.update_labels_randRelab(results, unlabeled_indexes, args.label_noise)

    print("Start training...")
    ####################################################################################################
    ###############################               TRAINING              ###############################
    ####################################################################################################
    for epoch in range(1, args.epoch + 1):
        st = time.time()
        scheduler.step()
        # train for one epoch
        print(args.experiment_name, args.labeled_samples)
        loss_per_epoch_train, top_5_train_ac, top1_train_acc_original_labels, \
            top1_train_ac, train_time = train_CrossEntropy_partialRelab(
                args, model, device, train_loader, optimizer, epoch, unlabeled_indexes)

        loss_train_epoch += [loss_per_epoch_train]

        loss_per_epoch_test, acc_val_per_epoch_i = testing(args, model, device, test_loader)

        loss_val_epoch += loss_per_epoch_test
        acc_train_per_epoch += [top1_train_ac]
        acc_val_per_epoch += acc_val_per_epoch_i

        ####################################################################################################
        #############################              SAVING MODELS              #############################
        ####################################################################################################
        if not os.path.exists('./checkpoints'):
            os.mkdir('./checkpoints')

        if epoch == 1:
            best_acc_val = acc_val_per_epoch_i[-1]
            snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestAccVal_%.5f' % (
                epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
        else:
            if acc_val_per_epoch_i[-1] > best_acc_val:
                best_acc_val = acc_val_per_epoch_i[-1]
                if cont > 0:
                    try:
                        os.remove(os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
                        os.remove(os.path.join(exp_path, snapBest + '.pth'))
                    except OSError:
                        pass
                snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestAccVal_%.5f' % (
                    epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
                torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
                torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))

        cont += 1

        if epoch == args.epoch:
            snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestValLoss_%.5f' % (
                epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

        ### Saving model to load it again
        # cond = epoch % 1 == 0
        if args.dataset_type == 'ssl_warmUp':
            if args.loss_term == 'Reg_ep':
                train_type = 'C'
            if args.loss_term == 'MixUp_ep':
                train_type = 'M'
            cond = (epoch == args.epoch)
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            save = True
        else:
            cond = (epoch == args.epoch)
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            save = True

        # print(cond)
        # print(save)
        if cond and save:
            print("Saving models...")
            path = './checkpoints/{0}_{1}_{2}_{3}_{4}_S{5}.hdf5'.format(
                name, epoch, args.dataset, args.labeled_samples, args.network, args.seed)
            save_checkpoint({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss_train_epoch': np.asarray(loss_train_epoch),
                'loss_val_epoch': np.asarray(loss_val_epoch),
                'acc_train_per_epoch': np.asarray(acc_train_per_epoch),
                'acc_val_per_epoch': np.asarray(acc_val_per_epoch),
                'labels': np.asarray(train_loader.dataset.soft_labels),
            }, filename=path)

        ####################################################################################################
        ############################              SAVING METRICS              #############################
        ####################################################################################################

        # Save losses:
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_train.npy', np.asarray(loss_train_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_val.npy', np.asarray(loss_val_epoch))

        # Save accuracies:
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_train.npy', np.asarray(acc_train_per_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_val.npy', np.asarray(acc_val_per_epoch))

    # applying swa
    if args.swa == 'True':
        optimizer.swap_swa_sgd()
        optimizer.bn_update(train_loader, model, device)
        loss_swa, acc_val_swa = testing(args, model, device, test_loader)

        snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestValLoss_%.5f_swaAcc_%.5f' % (
            epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val, acc_val_swa[0])
        torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

    print('Best acc: %f' % best_acc_val)
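# Both main() functions above repeat the same seeding block; a small helper
# such as this (set_seed is a hypothetical name, not in the original code)
# would factor it out:
import random
import numpy as np
import torch

def set_seed(seed):
    """Seed Python, NumPy, and PyTorch (CPU and all GPUs) and force determinism."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True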
def train_model(cfg, run_id, save_dir, use_cuda, args, writer):
    shuffle = True
    print("Run ID: " + args.run_id)
    print("Parameters used:")
    print("batch_size: " + str(args.batch_size))
    print("lr: " + str(args.learning_rate))
    print("loss weights: " + str(params.weights))

    if args.random_skip:
        skip = [x for x in range(0, 4)]
    else:
        skip = [args.skip]

    train_data_gen = Dataset(cfg, args.input_type, 'training', 1.0, args.num_clips,
                             skip, add_background=args.add_background)
    train_dataloader = DataLoader(
        train_data_gen, batch_size=args.batch_size, shuffle=shuffle,
        num_workers=args.num_workers,
        collate_fn=lambda b: filter_none(b, args.num_clips, args.varied_length))

    print("Number of training samples: " + str(len(train_data_gen)))
    steps_per_epoch = len(train_data_gen) / args.batch_size
    print("Steps per epoch: " + str(steps_per_epoch))

    if args.add_background:
        num_classes = cfg.num_classes + 1
    else:
        num_classes = cfg.num_classes

    assert args.num_clips > 1
    model = build_model(args.model_version, args.num_clips, num_classes,
                        args.feature_dim, args.hidden_dim, args.num_layers)

    num_gpus = len(args.gpu.split(','))
    if num_gpus > 1:
        model = torch.nn.DataParallel(model)
    if use_cuda:
        model.cuda()

    if args.optimizer == 'ADAM':
        print("Using ADAM optimizer")
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                     weight_decay=args.weight_decay)
    elif args.optimizer == 'SGD':
        print("Using SGD optimizer")
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
                                    momentum=0.9, weight_decay=args.weight_decay)

    scheduler = MultiStepLR(optimizer, milestones=[40, 80, 120, 160], gamma=0.5)
    if args.swa_start > 0:
        optimizer = SWA(optimizer)  # manual mode: the average advances only on update_swa()

    criterion = BCEWithLogitsLoss()

    max_fmap_score, fmap_score = 0, 0
    # loop over each epoch
    for epoch in range(args.num_epochs):
        model = train_epoch(cfg, run_id, epoch, train_dataloader, model, num_classes,
                            optimizer, criterion, writer, use_cuda, args,
                            weights=None, accumulation_steps=args.steps)
        if args.dataset in ['charades']:
            validation_interval = 10
            if epoch > 20:
                validation_interval = 5
        else:
            validation_interval = 50
            if epoch > 1000:
                validation_interval = 10

        if epoch % validation_interval == 0:
            fmap_score = val_epoch(cfg, epoch, model, writer, use_cuda, args)

        if fmap_score > max_fmap_score:
            # keep only the best checkpoint in save_dir
            for f in os.listdir(save_dir):
                os.remove(os.path.join(save_dir, f))
            save_file_path = os.path.join(save_dir, 'model_{}_{:.4f}.pth'.format(epoch, fmap_score))
            states = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(states, save_file_path)
            max_fmap_score = fmap_score
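# train_model wraps the optimizer as SWA(optimizer) with no swa_start/swa_freq,
# i.e. torchcontrib's manual mode: the average only advances when update_swa()
# is called. A minimal sketch of that mode (model, data, and the update cadence
# below are illustrative assumptions):
import torch
import torch.nn as nn
from torchcontrib.optim import SWA

_model = nn.Linear(10, 2)
_opt = SWA(torch.optim.Adam(_model.parameters(), lr=1e-3))  # manual mode

for _step in range(500):
    _x, _y = torch.randn(8, 10), torch.randint(0, 2, (8,))
    _opt.zero_grad()
    nn.functional.cross_entropy(_model(_x), _y).backward()
    _opt.step()
    if _step > 100 and _step % 50 == 0:
        _opt.update_swa()  # fold the current weights into the running average

_opt.swap_swa_sgd()  # use the averaged weights for evaluation / saving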
class Trainer(object):
    def __init__(self, args, train_dataloader=None, validate_dataloader=None, test_dataloader=None):
        self.args = args
        self.train_dataloader = train_dataloader
        self.validate_dataloader = validate_dataloader
        self.test_dataloader = test_dataloader

        self.label_lst = [i for i in range(self.args.num_classes)]
        self.num_labels = self.args.num_classes

        self.config_class = AutoConfig
        self.model_class = BertForSequenceClassification
        self.config = self.config_class.from_pretrained(
            self.args.bert_model_name,
            num_labels=self.num_labels,
            finetuning_task='nsmc',
            id2label={str(i): label for i, label in enumerate(self.label_lst)},
            label2id={label: i for i, label in enumerate(self.label_lst)})
        self.model = self.model_class.from_pretrained(self.args.bert_model_name, config=self.config)

        self.optimizer = None
        self.scheduler = None

        # GPU or CPU
        self.device = "cuda" if torch.cuda.is_available() and args.cuda else "cpu"
        self.model.to(self.device)

    def train(self, alpha, gamma):
        train_dataloader = self.train_dataloader
        t_total = len(train_dataloader) * self.args.num_epochs

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in self.model.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            },
            {
                'params': [p for n, p in self.model.named_parameters()
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            }
        ]
        if self.args.use_swa:
            base_opt = AdamW(optimizer_grouped_parameters, lr=self.args.lr, eps=1e-8)
            self.optimizer = SWA(base_opt,
                                 swa_start=4 * len(train_dataloader),
                                 swa_freq=100,
                                 swa_lr=5e-5)
            # expose the inner AdamW's attributes through the SWA wrapper so
            # the LR scheduler and checkpointing see the same param groups
            self.optimizer.param_groups = self.optimizer.optimizer.param_groups
            self.optimizer.state = self.optimizer.optimizer.state
            self.optimizer.defaults = self.optimizer.optimizer.defaults
        else:
            self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.lr, eps=1e-8)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=100,
            num_training_steps=self.args.num_epochs * len(train_dataloader))

        self.criterion = FocalLoss(alpha=alpha, gamma=gamma)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(self.train_dataloader) * self.args.batch_size)
        logger.info("  Num Epochs = %d", self.args.num_epochs)
        logger.info("  Total train batch size = %d", self.args.batch_size)
        logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        tr_loss = 0.0
        self.model.zero_grad()
        self.optimizer.zero_grad()

        train_iterator = trange(int(self.args.num_epochs), desc="Epoch")
        fin_result = None
        f1_max = 0.0
        self.model.train()
        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):
                batch = tuple(t.to(self.device) for t in batch)  # GPU or CPU
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3],
                    'token_type_ids': batch[2]
                }

                # outputs = self.model(**inputs)
                # loss = outputs[0]

                # Custom loss: sigmoid over the logits, focal loss on one-hot targets
                loss, logits = self.model(**inputs)
                logits = torch.sigmoid(logits)
                labels = torch.zeros((len(batch[3]), self.num_labels)).to(self.device)
                labels[range(len(batch[3])), batch[3]] = 1
                loss = self.criterion(logits, labels)

                loss.backward()
                self.optimizer.step()
                self.scheduler.step()  # update the learning rate schedule
                self.model.zero_grad()
                self.optimizer.zero_grad()
                tr_loss += loss.item()
                global_step += 1
                logger.info('train loss %f', loss.item())
            logger.info('total train loss %f', tr_loss / global_step)

            # swap in the SWA weights for evaluation, then swap back for training
            if epoch >= 4 and self.args.use_swa:
                self.optimizer.swap_swa_sgd()
            fin_result = self.evaluate("validate")
            self.save_model(epoch)
            self.model.train()
            if epoch >= 4 and self.args.use_swa:
                self.optimizer.swap_swa_sgd()
            f1_max = max(fin_result['f1_macro'], f1_max)

        if epoch >= 4 and self.args.use_swa:
            self.optimizer.swap_swa_sgd()
        with open(os.path.join(self.args.base_dir, self.args.result_dir,
                               self.args.train_id, 'param_seach.txt'),
                  "a", encoding="utf-8") as f:
            f.write('alpha: {}, gamma: {}, f1_macro: {}\n'.format(alpha, gamma, f1_max))
        return f1_max

    def evaluate(self, mode='test'):
        if mode == 'test':
            dataloader = self.test_dataloader
        elif mode == 'validate':
            dataloader = self.validate_dataloader
        else:
            raise Exception("Only dev and test datasets are available")

        # Eval!
        logger.info("***** Running evaluation on %s dataset *****", mode)
        logger.info("  Num examples = %d", len(dataloader) * self.args.batch_size)
        logger.info("  Batch size = %d", self.args.batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None

        self.model.eval()
        for batch in tqdm(dataloader, desc="Evaluating"):
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3],
                    'token_type_ids': batch[2]
                }
                outputs = self.model(**inputs)
                tmp_eval_loss, logits = outputs[:2]
                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1

            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs['labels'].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids,
                                          inputs['labels'].detach().cpu().numpy(), axis=0)

        eval_loss = eval_loss / nb_eval_steps
        results = {"loss": eval_loss}
        preds = np.argmax(preds, axis=1)

        result = compute_metrics(preds, out_label_ids)
        results.update(result)
        p_macro, r_macro, f_macro, support_macro = precision_recall_fscore_support(
            y_true=out_label_ids, y_pred=preds,
            labels=[i for i in range(self.num_labels)], average='macro')
        results.update({
            'precision': p_macro,
            'recall': r_macro,
            'f1_macro': f_macro
        })

        with open(self.args.prediction_file, "w", encoding="utf-8") as f:
            for pred in preds:
                f.write("{}\n".format(pred))

        if mode == 'validate':
            logger.info("***** Eval results *****")
            for key in sorted(results.keys()):
                logger.info("  %s = %s", key, str(results[key]))

        return results

    def save_model(self, num=0):
        state = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()
        }
        torch.save(state, os.path.join(self.args.base_dir, self.args.result_dir,
                                       self.args.train_id, 'epoch_' + str(num) + '.pth'))
        logger.info('model saved')

    def load_model(self, model_name):
        state = torch.load(os.path.join(self.args.base_dir, self.args.result_dir,
                                        self.args.train_id, model_name))
        self.model.load_state_dict(state['model'])
        if self.optimizer is not None:
            self.optimizer.load_state_dict(state['optimizer'])
        if self.scheduler is not None:
            self.scheduler.load_state_dict(state['scheduler'])
        logger.info('model loaded')
progress["train_loss"].append(np.mean(train_loss)) progress["val_loss"].append(np.mean(val_loss)) progress["val_accuracy"].append(accuracy) dict2df(progress, args.output_dir + 'progress.csv') scheduler_step(optimizer, scheduler, accuracy, args) # --------------------------------------------------------------------------------------------------------------- # # --------------------------------------------------------------------------------------------------------------- # # --------------------------------------------------------------------------------------------------------------- # if args.apply_swa: torch.save( optimizer.state_dict(), args.output_dir + "/optimizer_" + args.model_name + "_before_swa_swap.pt") optimizer.swap_swa_sgd( ) # Set the weights of your model to their SWA averages optimizer.bn_update(train_loader, model, device='cuda') torch.save( model.state_dict(), args.output_dir + "/swa_checkpoint_last_bn_update_{}epochs_lr{}.pt".format( args.epochs, args.swa_lr)) accuracy, val_loss = val_step_accuracy(val_loader, model, criterion, weights_criterion, multiclass_criterion,
def main():
    maxIOU = 0.0
    assert torch.cuda.is_available()
    torch.backends.cudnn.benchmark = True
    model_fname = '../data/model_swa_8/deeplabv3_{0}_epoch%d.pth'.format('crops')
    focal_loss = FocalLoss2d()
    train_dataset = CropSegmentation(train=True, crop_size=args.crop_size)
    # test_dataset = CropSegmentation(train=False, crop_size=args.crop_size)
    model = torchvision.models.segmentation.deeplabv3_resnet50(
        pretrained=False, progress=True, num_classes=5, aux_loss=True)

    if args.train:
        # note: weight has 4 entries while the model predicts 5 classes;
        # it is currently unused because the criterion below ignores it
        weight = np.ones(4)
        weight[2] = 5
        weight[3] = 5
        w = torch.FloatTensor(weight).cuda()
        criterion = nn.CrossEntropyLoss()  # ignore_index=255, weight=w
        model = nn.DataParallel(model).cuda()
        for param in model.parameters():
            param.requires_grad = True
        optimizer1 = optim.SGD(model.parameters(), lr=config.lr,
                               momentum=0.9, weight_decay=1e-4)
        # the scheduler must wrap the base optimizer defined above;
        # the SWA wrapper is created afterwards
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer1, T_max=(args.epochs // 9) + 1)
        optimizer = SWA(optimizer1)
        dataset_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=args.train,
            pin_memory=True, num_workers=args.workers)
        max_iter = args.epochs * len(dataset_loader)
        losses = AverageMeter()
        start_epoch = 0

        if args.resume:
            if os.path.isfile(args.resume):
                print('=> loading checkpoint {0}'.format(args.resume))
                checkpoint = torch.load(args.resume)
                start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print('=> loaded checkpoint {0} (epoch {1})'.format(args.resume, checkpoint['epoch']))
            else:
                print('=> no checkpoint found at {0}'.format(args.resume))

        for epoch in range(start_epoch, args.epochs):
            scheduler.step(epoch)
            model.train()
            for i, (inputs, target) in enumerate(dataset_loader):
                inputs = Variable(inputs.cuda())
                target = Variable(target.cuda())
                outputs = model(inputs)
                loss1 = focal_loss(outputs['out'], target)
                loss2 = focal_loss(outputs['aux'], target)
                loss01 = loss1 + 0.1 * loss2
                loss3 = lovasz_softmax(outputs['out'], target)
                loss4 = lovasz_softmax(outputs['aux'], target)
                loss02 = loss3 + 0.1 * loss4
                loss = loss01 + loss02
                if np.isnan(loss.item()) or np.isinf(loss.item()):
                    pdb.set_trace()
                losses.update(loss.item(), args.batch_size)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                if i > 10 and i % 5 == 0:
                    optimizer.update_swa()  # manual SWA: advance the running average
                print('epoch: {0}\t'
                      'iter: {1}/{2}\t'
                      'loss: {loss.val:.4f} ({loss.ema:.4f})'.format(
                          epoch + 1, i + 1, len(dataset_loader), loss=losses))
            if epoch > 5:
                torch.save({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, model_fname % (epoch + 1))

        optimizer.swap_swa_sgd()
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, model_fname % (665 + 1))
def main():
    if args.config_path:
        if args.config_path in CONFIG_TREATER:
            load_path = CONFIG_TREATER[args.config_path]
        elif args.config_path.endswith(".yaml"):
            load_path = args.config_path
        else:
            load_path = "experiments/" + CONFIG_TREATER[args.config_path] + ".yaml"
        with open(load_path, 'rb') as fp:
            config = CfgNode.load_cfg(fp)
    else:
        config = None

    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

    test_model = None
    max_epoch = config.TRAIN.NUM_EPOCHS
    print('data folder: ', args.data_folder)

    # WORLD_SIZE is generated by torch.distributed.launch
    # num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    # is_distributed = num_gpus > 1
    # if is_distributed:
    #     torch.cuda.set_device(args.local_rank)
    #     torch.distributed.init_process_group(
    #         backend="nccl", init_method="env://",
    #     )

    model = get_model(config)
    model_loss = ModelLossWraper(
        model,
        config.TRAIN.CLASS_WEIGHTS,
        config.MODEL.IS_DISASTER_PRED,
        config.MODEL.IS_SPLIT_LOSS,
    ).cuda()

    # if is_distributed:
    #     model_loss = nn.SyncBatchNorm.convert_sync_batchnorm(model_loss)
    #     model_loss = nn.parallel.DistributedDataParallel(
    #         model_loss  # , device_ids=[args.local_rank], output_device=args.local_rank
    #     )
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        model_loss = nn.DataParallel(model_loss)
    model_loss.to(device)

    cpucount = multiprocessing.cpu_count()
    if config.mode.startswith("single"):
        trainset_loaders = {}
        loader_len = 0
        for disaster in disaster_list[config.mode[6:]]:
            trainset = XView2Dataset(args.data_folder, rgb_bgr='rgb',
                                     preprocessing={
                                         'flip': True,
                                         'scale': config.TRAIN.MULTI_SCALE,
                                         'crop': config.TRAIN.CROP_SIZE,
                                     },
                                     mode="singletrain",
                                     single_disaster=disaster)
            if len(trainset) > 0:
                train_sampler = None
                trainset_loader = torch.utils.data.DataLoader(
                    trainset, batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
                    shuffle=train_sampler is None, pin_memory=True, drop_last=True,
                    sampler=train_sampler,
                    num_workers=cpucount if cpucount < 16 else cpucount // 3)
                trainset_loaders[disaster] = trainset_loader
                loader_len += len(trainset_loader)
                print("added disaster {} with {} samples".format(disaster, len(trainset)))
            else:
                print("skipping disaster ", disaster)
    else:
        trainset = XView2Dataset(args.data_folder, rgb_bgr='rgb',
                                 preprocessing={
                                     'flip': True,
                                     'scale': config.TRAIN.MULTI_SCALE,
                                     'crop': config.TRAIN.CROP_SIZE,
                                 },
                                 mode=config.mode)
        # if is_distributed:
        #     train_sampler = DistributedSampler(trainset)
        # else:
        train_sampler = None
        trainset_loader = torch.utils.data.DataLoader(
            trainset, batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
            shuffle=train_sampler is None, pin_memory=True, drop_last=True,
            sampler=train_sampler, num_workers=multiprocessing.cpu_count())
        loader_len = len(trainset_loader)

    model.train()

    lr_init = config.TRAIN.LR
    optimizer = torch.optim.SGD(
        [{'params': filter(lambda p: p.requires_grad, model.parameters()),
          'lr': lr_init}],
        lr=lr_init, momentum=0.9, weight_decay=0., nesterov=False,
    )
    num_iters = max_epoch * loader_len
    if config.SWA:
        swa_start = num_iters
        optimizer = SWA(optimizer, swa_start=swa_start,
                        swa_freq=4 * loader_len, swa_lr=0.001)
    # scheduler = torch.optim.lr_scheduler.CyclicLR(
    #     optimizer, 0.0001, 0.05, step_size_up=1,
    #     step_size_down=2 * len(trainset_loader) - 1, mode='triangular',
    #     gamma=1.0, scale_fn=None, scale_mode='cycle', cycle_momentum=False,
    #     base_momentum=0.8, max_momentum=0.9, last_epoch=-1)
    lr = 0.0001

    # model.load_state_dict(torch.load("ckpt/dual-hrnet/hrnet_450", map_location='cpu')['state_dict'])
    # print("weights loaded")
    max_epoch = max_epoch + 40
    start_epoch = 0
    losses = AverageMeter()
    model.train()
    cur_iters = 0 if start_epoch == 0 else None
    for epoch in range(start_epoch, max_epoch):
        if config.mode.startswith("single"):
            all_batches = []
            total_len = 0
            for disaster in sorted(list(trainset_loaders.keys())):
                all_batches += [(disaster, idx)
                                for idx in range(len(trainset_loaders[disaster]))]
                total_len += len(trainset_loaders[disaster].dataset)
            all_batches = random.sample(all_batches, len(all_batches))
            iterators = {disaster: iter(trainset_loaders[disaster])
                         for disaster in trainset_loaders.keys()}
            if cur_iters is not None:
                cur_iters += len(all_batches)
            else:
                cur_iters = epoch * len(all_batches)
            for i, (disaster, idx) in enumerate(all_batches):
                lr = optimizer.param_groups[0]['lr']
                if not config.SWA or epoch < swa_start:
                    lr = adjust_learning_rate(optimizer, lr_init, num_iters, i + cur_iters)
                samples = next(iterators[disaster])
                inputs_pre = samples['pre_img'].to(device)
                inputs_post = samples['post_img'].to(device)
                target = samples['mask_img'].to(device)
                # disaster_target = samples['disaster'].to(device)

                loss = model_loss(inputs_pre, inputs_post, target)  # , disaster_target)
                loss_sum = torch.sum(loss).detach().cpu()
                if np.isnan(loss_sum) or np.isinf(loss_sum):
                    print('check')
                losses.update(loss_sum, 4)  # batch size

                loss = torch.sum(loss)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                if args.local_rank == 0 and i % 10 == 0:
                    logger.info('epoch: {0}\t'
                                'iter: {1}/{2}\t'
                                'lr: {3:.6f}\t'
                                'loss: {loss.val:.4f} ({loss.ema:.4f})\t'
                                'disaster: {dis}'.format(
                                    epoch + 1, i + 1, len(all_batches), lr,
                                    loss=losses, dis=disaster))
            del iterators
        else:
            cur_iters = epoch * len(trainset_loader)
            for i, samples in enumerate(trainset_loader):
                lr = optimizer.param_groups[0]['lr']
                if not config.SWA or epoch < swa_start:
                    lr = adjust_learning_rate(optimizer, lr_init, num_iters, i + cur_iters)
                inputs_pre = samples['pre_img'].to(device)
                inputs_post = samples['post_img'].to(device)
                target = samples['mask_img'].to(device)
                # disaster_target = samples['disaster'].to(device)

                loss = model_loss(inputs_pre, inputs_post, target)  # , disaster_target)
                loss_sum = torch.sum(loss).detach().cpu()
                if np.isnan(loss_sum) or np.isinf(loss_sum):
                    print('check')
                losses.update(loss_sum, 4)  # batch size

                loss = torch.sum(loss)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                # if args.swa == "True":
                #     scheduler.step()
                #     if epoch % 4 == 3 and i == len(trainset_loader) - 2:
                #         optimizer.update_swa()
                if args.local_rank == 0 and i % 10 == 0:
                    logger.info('epoch: {0}\t'
                                'iter: {1}/{2}\t'
                                'lr: {3:.6f}\t'
                                'loss: {loss.val:.4f} ({loss.ema:.4f})'.format(
                                    epoch + 1, i + 1, len(trainset_loader), lr, loss=losses))

        if args.local_rank == 0:
            if (epoch + 1) % 50 == 0 and test_model is None:
                torch.save({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, os.path.join(ckpts_save_dir, 'hrnet_%s' % (epoch + 1)))

    if config.SWA:
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, os.path.join(ckpts_save_dir, 'hrnet_%s' % ("preSWA")))
        optimizer.swap_swa_sgd()
        bn_loader = torch.utils.data.DataLoader(
            trainset, batch_size=2, shuffle=train_sampler is None,
            pin_memory=True, drop_last=True, sampler=train_sampler,
            num_workers=multiprocessing.cpu_count())
        bn_update(bn_loader, model, device='cuda')
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, os.path.join(ckpts_save_dir, 'hrnet_%s' % ("SWA")))
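# All snippets above use torchcontrib, which is no longer maintained. Since
# PyTorch 1.6 the same workflow is available in torch.optim.swa_utils; a
# minimal sketch of the modern equivalent (model, data, and schedule values
# below are illustrative assumptions):
import torch
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

_model = nn.Sequential(nn.Linear(10, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 2))
_opt = torch.optim.SGD(_model.parameters(), lr=0.1, momentum=0.9)
_swa_model = AveragedModel(_model)   # keeps the running average of the weights
_swa_sched = SWALR(_opt, swa_lr=0.01)

_loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(50)]
_swa_start = 5
for _epoch in range(10):
    for _x, _y in _loader:
        _opt.zero_grad()
        nn.functional.cross_entropy(_model(_x), _y).backward()
        _opt.step()
    if _epoch >= _swa_start:
        _swa_model.update_parameters(_model)  # fold current weights into the average
        _swa_sched.step()

update_bn(_loader, _swa_model)  # recompute BatchNorm statistics for the averaged model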