def validate(val_loader, model, args):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for inputs, labels in val_loader:
            data_time.update(time.time() - end)
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)

            acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
            top1.update(acc1.item(), inputs.size(0))
            top5.update(acc5.item(), inputs.size(0))
            batch_time.update(time.time() - end)
            end = time.time()

    throughput = 1.0 / (batch_time.avg / inputs.size(0))

    return top1.avg, top5.avg, throughput
# NOTE: this redefinition shadows the validate() above; it is the version
# actually called from main() and additionally reports a softmax
# negative log-likelihood ("rb loss") used for model selection.
def validate(val_loader, model, args):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    result = 0
    with torch.no_grad():
        end = time.time()
        labels_len = 0
        for inputs, labels in val_loader:
            data_time.update(time.time() - end)
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            outputs = F.softmax(outputs, dim=1)

            # 4-class task: track top-1 and top-2 (stored in the "top5" meter)
            acc1, acc5 = accuracy(outputs, labels, topk=(1, 2))
            top1.update(acc1.item(), inputs.size(0))
            top5.update(acc5.item(), inputs.size(0))
            batch_time.update(time.time() - end)
            end = time.time()

            # accumulate the log-probability of the true class for each sample
            pred_score = outputs.data.cpu().numpy()
            for i in range(len(labels)):
                for j in range(4):
                    if labels[i] == j:
                        result += math.log(pred_score[i][j])
            labels_len += len(labels)

    # average negative log-likelihood over the validation set
    result = -result / labels_len
    logger.info(f'rb loss: {result}')

    throughput = 1.0 / (batch_time.avg / inputs.size(0))

    return top1.avg, top5.avg, throughput, result
def train(train_loader, model, criterion, optimizer, scheduler, epoch, args):
    top1 = AverageMeter()
    top5 = AverageMeter()
    losses = AverageMeter()

    # switch to train mode
    model.train()

    # gpus_num and local_rank are assumed to be module-level globals set up
    # elsewhere in this file (multi-GPU configuration, not shown here)
    iters = len(train_loader.dataset) // (args.per_node_batch_size * gpus_num)
    prefetcher = DataPrefetcher(train_loader)
    inputs, labels = prefetcher.next()
    iter_index = 1

    while inputs is not None:
        inputs, labels = inputs.cuda(), labels.cuda()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss = loss / args.accumulation_steps

        if args.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        # step the optimizer only every accumulation_steps iterations
        if iter_index % args.accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        # measure accuracy and record loss
        acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))
        losses.update(loss.item(), inputs.size(0))

        inputs, labels = prefetcher.next()

        if local_rank == 0 and iter_index % args.print_interval == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, iter [{iter_index:0>4d}, {iters:0>4d}], "
                f"lr: {scheduler.get_lr()[0]:.6f}, top1 acc: {acc1.item():.2f}%, "
                f"top5 acc: {acc5.item():.2f}%, loss_total: {loss.item():.2f}")

        iter_index += 1

    scheduler.step()

    return top1.avg, top5.avg, losses.avg
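# AverageMeter, accuracy, and DataPrefetcher are assumed to be defined
# elsewhere in this repo (they follow the official PyTorch ImageNet example).
# The two sketches below are hypothetical stand-ins, included only so the
# usage above is easy to follow; prefer the repo's own implementations.
class AverageMeter:
    """Tracks the latest value and the running average of a metric."""

    def __init__(self):
        self.val, self.avg, self.sum, self.count = 0.0, 0.0, 0.0, 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1, )):
    """Computes top-k accuracy (in percent) for each k in topk."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # indices of the maxk highest-scoring classes per sample
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res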
def main(logger, args):
    if not torch.cuda.is_available():
        raise Exception("need gpu to train network!")

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    gpus = torch.cuda.device_count()
    logger.info(f'use {gpus} gpus')
    logger.info(f"args: {args}")

    cudnn.benchmark = True
    cudnn.enabled = True
    start_time = time.time()

    # dataset and dataloader
    logger.info('start loading data')
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              pin_memory=True,
                              num_workers=args.num_workers)
    val_loader = DataLoader(Config.val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            pin_memory=True,
                            num_workers=args.num_workers)
    logger.info('finish loading data')

    # create the sub-models to be ensembled (all with 4-class heads)
    model_b3 = models.__dict__['efficientnet_b3'](**{
        "pretrained": False,
        "num_classes": 4,
    })
    model_12GF = models.__dict__['RegNetY_12GF'](**{
        "pretrained": False,
        "num_classes": 4,
    })
    '''
    model_vovnet = models.__dict__['VoVNet99_se'](**{
        "pretrained": False,
        "num_classes": 4,
    })
    '''
    model_32GF = models.__dict__['RegNetY_32GF'](**{
        "pretrained": False,
        "num_classes": 4,
    })

    # freeze every sub-model parameter; only the ensemble head is trained
    for name, param in model_b3.named_parameters():
        param.requires_grad = False
        logger.info(f"{name},{param.requires_grad}")
    for name, param in model_12GF.named_parameters():
        param.requires_grad = False
        logger.info(f"{name},{param.requires_grad}")
    '''
    for name, param in model_vovnet.named_parameters():
        param.requires_grad = False
        logger.info(f"{name},{param.requires_grad}")
    '''
    for name, param in model_32GF.named_parameters():
        param.requires_grad = False
        logger.info(f"{name},{param.requires_grad}")

    # merge the sub-models into the ensemble
    logger.info(f"creating ensemble model")
    # model = JSTNET(model_b3, model_12GF, model_32GF, model_vovnet)
    model = JSTNET(model_b3, model_12GF, model_32GF)

    model = model.cuda()
    model_b3 = model_b3.cuda()
    model_12GF = model_12GF.cuda()
    model_32GF = model_32GF.cuda()
    # model_vovnet = model_vovnet.cuda()

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # linear warm-up followed by cosine annealing
    warm_up_with_cosine_lr = lambda epoch: (
        epoch / args.warm_up_epochs if epoch <= args.warm_up_epochs
        else 0.5 * (math.cos((epoch - args.warm_up_epochs) /
                             (args.epochs - args.warm_up_epochs) * math.pi) + 1))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=warm_up_with_cosine_lr)

    if args.apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    model = nn.DataParallel(model)

    # load the pretrained weights of each sub-model
    my_path = '/home/jns2szh/code/pytorch-ImageNet-CIFAR-COCO-VOC-training-master/imagenet_experiments/'
    logger.info(f"start load model")
    checkpoint_b3 = torch.load(
        my_path + 'efficientnet_imagenet_DataParallel_train_example/checkpoints_b3/latest.pth',
        map_location=torch.device('cpu'))
    new_state_dict = OrderedDict()
    for k, v in checkpoint_b3['model_state_dict'].items():
        name = k[7:]  # remove the 'module.' prefix added by DataParallel
        if name != 'fc.weight' and name != 'fc.bias':
            new_state_dict[name] = v
    # the fc weights were filtered out above, so allow missing keys
    model_b3.load_state_dict(new_state_dict, strict=False)
    logger.info(f"load b3 model finished")

    checkpoint_12GF = torch.load(
        my_path + 'regnet_imagenet_Dataparallel_train_example/regnet_12/latest.pth',
        map_location=torch.device('cpu'))
    new_state_dict = OrderedDict()
    for k, v in checkpoint_12GF['model_state_dict'].items():
        name = k[7:]  # remove the 'module.' prefix added by DataParallel
        if name != 'fc.weight' and name != 'fc.bias':
            new_state_dict[name] = v
    model_12GF.load_state_dict(new_state_dict, strict=False)
    logger.info(f"load 12GF model finished")

    '''
    checkpoint_vovnet = torch.load(
        my_path + 'vovnet_Dataparallel_train_example/checkpoints/latest.pth',
        map_location=torch.device('cpu'))
    new_state_dict = OrderedDict()
    for k, v in checkpoint_vovnet['model_state_dict'].items():
        name = k[7:]  # remove the 'module.' prefix added by DataParallel
        if name != 'fc.weight' and name != 'fc.bias':
            new_state_dict[name] = v
    model_vovnet.load_state_dict(new_state_dict, strict=False)
    logger.info(f"load vovnet model finished")
    '''

    checkpoint_32GF = torch.load(
        my_path + 'regnet_imagenet_Dataparallel_train_example/checkpoints/latest.pth',
        map_location=torch.device('cpu'))
    new_state_dict = OrderedDict()
    for k, v in checkpoint_32GF['model_state_dict'].items():
        name = k[7:]  # remove the 'module.' prefix added by DataParallel
        if name != 'fc.weight' and name != 'fc.bias':
            new_state_dict[name] = v
    model_32GF.load_state_dict(new_state_dict, strict=False)
    logger.info(f"load 32GF model finished")

    # resume training
    start_epoch = 0
    if os.path.exists(args.resume):
        logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        start_epoch += checkpoint['epoch']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        logger.info(
            f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, "
            f"loss: {checkpoint['loss']:3f}, lr: {checkpoint['lr']:.6f}, "
            f"top1_acc: {checkpoint['acc1']}%")

    if not os.path.exists(args.checkpoints):
        os.makedirs(args.checkpoints)

    logger.info('start training')
    min_rb_loss = 1000
    for epoch in range(start_epoch, args.epochs + 1):
        # training loop, inlined here instead of calling train(): the ensemble
        # uses top-1/top-2 accuracy and no gradient accumulation
        top1 = AverageMeter()
        top5 = AverageMeter()
        losses = AverageMeter()

        # switch to train mode
        model.train()

        iters = len(train_loader.dataset) // args.batch_size
        prefetcher = DataPrefetcher(train_loader)
        inputs, labels = prefetcher.next()
        iter_index = 1

        while inputs is not None:
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if args.apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            # measure accuracy and record loss (4 classes, so top-1/top-2)
            acc1, acc5 = accuracy(outputs, labels, topk=(1, 2))
            top1.update(acc1.item(), inputs.size(0))
            top5.update(acc5.item(), inputs.size(0))
            losses.update(loss.item(), inputs.size(0))

            inputs, labels = prefetcher.next()

            if iter_index % args.print_interval == 0:
                logger.info(
                    f"train: epoch {epoch:0>3d}, iter [{iter_index:0>4d}, {iters:0>4d}], "
                    f"lr: {scheduler.get_lr()[0]:.6f}, top1 acc: {acc1.item():.2f}%, "
                    f"top5 acc: {acc5.item():.2f}%, loss_total: {loss.item():.2f}")

            iter_index += 1

        scheduler.step()

        acc1, acc5, throughput, rb_loss = validate(val_loader, model, args)
        logger.info(
            f"val: epoch {epoch:0>3d}, top1 acc: {acc1:.2f}%, top5 acc: {acc5:.2f}%, "
            f"throughput: {throughput:.2f}sample/s")

        # keep the checkpoint with the lowest validation rb loss
        if rb_loss < min_rb_loss:
            min_rb_loss = rb_loss
            logger.info("save model")
            torch.save(
                {
                    'epoch': epoch,
                    'acc1': acc1,
                    'loss': losses.avg,
                    'lr': scheduler.get_lr()[0],
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                },
                os.path.join(args.checkpoints, 'latest.pth'))

        if epoch == args.epochs:
            torch.save(
                {
                    'epoch': epoch,
                    'acc1': acc1,
                    'loss': losses.avg,
                    'lr': scheduler.get_lr()[0],
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                },
                os.path.join(
                    args.checkpoints,
                    "{}-epoch{}-acc{}.pth".format('JSTNET', epoch, acc1)))

    training_time = (time.time() - start_time) / 3600
    logger.info(
        f"finish training, total training time: {training_time:.2f} hours")
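# The original file constructs `logger` and `args` elsewhere (this section only
# defines validate/train/main). The entry point below is a hypothetical sketch:
# the argument names are inferred from how `args` is used above and may not
# match the repo's actual CLI or defaults.
if __name__ == '__main__':
    import argparse
    import logging

    parser = argparse.ArgumentParser(description='JSTNET ensemble training (sketch)')
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--warm_up_epochs', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--print_interval', type=int, default=100)
    # only used by train(), which main() does not call directly
    parser.add_argument('--accumulation_steps', type=int, default=1)
    parser.add_argument('--per_node_batch_size', type=int, default=64)
    parser.add_argument('--apex', action='store_true')
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--resume', type=str, default='')
    parser.add_argument('--checkpoints', type=str, default='./checkpoints')
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    main(logger, args)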