import os
import shutil
import time

import torch
import torch.nn as nn

# Project-level helpers used below (AverageMeter, accuracy, learning_rate_decay,
# save_checkpoint, trigger_job_requeue, DistUnifTargSampler, logger) are assumed
# to be imported from the project's utility modules.


def train_network(args, model, reglog, optimizer, loader):
    """
    Train the models on the dataset.
    """
    # running statistics
    batch_time = AverageMeter()
    data_time = AverageMeter()

    # training statistics
    log_top1 = AverageMeter()
    log_loss = AverageMeter()

    end = time.perf_counter()

    if 'pascal' in args.data_path:
        criterion = nn.BCEWithLogitsLoss(reduction='none')
    else:
        criterion = nn.CrossEntropyLoss().cuda()

    for iter_epoch, (inp, target) in enumerate(loader):
        # measure data loading time
        data_time.update(time.perf_counter() - end)

        learning_rate_decay(optimizer, len(loader) * args.epoch + iter_epoch, args.lr)

        # start at iter start_iter
        if iter_epoch < args.start_iter:
            continue

        # move to gpu
        inp = inp.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        if 'pascal' in args.data_path:
            target = target.float()

        # forward
        with torch.no_grad():
            output = model(inp)
        output = reglog(output)

        # compute cross entropy loss
        loss = criterion(output, target)
        if 'pascal' in args.data_path:
            mask = (target == 255)
            loss = torch.sum(loss.masked_fill_(mask, 0)) / target.size(0)

        optimizer.zero_grad()

        # compute the gradients
        loss.backward()

        # step
        optimizer.step()

        # log

        # signal received, relaunch experiment
        if os.environ['SIGNAL_RECEIVED'] == 'True':
            if not args.rank:
                torch.save({
                    'epoch': args.epoch,
                    'start_iter': iter_epoch + 1,
                    'state_dict': reglog.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, os.path.join(args.dump_path, 'checkpoint.pth.tar'))
                trigger_job_requeue(os.path.join(args.dump_path, 'checkpoint.pth.tar'))

        # update stats
        log_loss.update(loss.item(), output.size(0))
        if 'pascal' not in args.data_path:
            prec1 = accuracy(args, output, target)
            log_top1.update(prec1.item(), output.size(0))

        batch_time.update(time.perf_counter() - end)
        end = time.perf_counter()

        # verbose
        if iter_epoch % 100 == 0:
            logger.info('Epoch[{0}] - Iter: [{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Prec {log_top1.val:.3f} ({log_top1.avg:.3f})\t'
                        .format(args.epoch, iter_epoch, len(loader),
                                batch_time=batch_time, data_time=data_time,
                                loss=log_loss, log_top1=log_top1))

    # end of epoch
    args.start_iter = 0
    args.epoch += 1

    # dump checkpoint
    if not args.rank:
        torch.save({
            'epoch': args.epoch,
            'start_iter': 0,
            'state_dict': reglog.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, os.path.join(args.dump_path, 'checkpoint.pth.tar'))

    return (args.epoch - 1, args.epoch * len(loader), log_top1.avg, log_loss.avg)
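
# The running and training statistics above rely on an `AverageMeter` helper
# that is not defined in this file. The class below is a minimal sketch of the
# interface the training loops assume (update(val, n), plus .val and .avg
# attributes); the project's actual implementation may differ.
class AverageMeter:
    """Tracks the latest value and the running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0   # most recent value
        self.avg = 0.0   # running average
        self.sum = 0.0   # weighted sum of values
        self.count = 0   # total weight (e.g. number of samples)

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
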
def train_network(args, models, optimizers, dataset):
    """
    Train the models with cluster assignments as targets
    """
    # switch to train mode
    for model in models:
        model.train()

    # uniform sampling over pseudo labels
    sampler = DistUnifTargSampler(
        args.epoch_size,
        dataset.sub_classes,
        args.training_local_world_size,
        args.training_local_rank,
        seed=args.epoch + args.training_local_world_id,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
    )

    # running statistics
    batch_time = AverageMeter()
    data_time = AverageMeter()

    # training statistics
    log_top1_subclass = AverageMeter()
    log_loss_subclass = AverageMeter()
    log_top1_superclass = AverageMeter()
    log_loss_superclass = AverageMeter()
    log_top1 = AverageMeter()
    log_loss = AverageMeter()

    end = time.perf_counter()

    cel = nn.CrossEntropyLoss().cuda()
    relu = torch.nn.ReLU().cuda()

    for iter_epoch, (inp, target) in enumerate(loader):
        # start at iter start_iter
        if iter_epoch < args.start_iter:
            continue

        # measure data loading time
        data_time.update(time.perf_counter() - end)

        # move input to gpu
        inp = inp.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True).long()

        # forward on the model
        inp = relu(models[0](inp))

        # forward on sub-class prediction layer
        output = models[-1](inp)
        loss_subclass = cel(output, target)

        # forward on super-class prediction layer
        super_class_output = models[1](inp)
        sc_target = args.training_local_world_id + \
            0 * torch.cuda.LongTensor(args.batch_size)
        loss_superclass = cel(super_class_output, sc_target)

        loss = loss_subclass + loss_superclass

        # initialize the optimizers
        for optimizer in optimizers:
            optimizer.zero_grad()

        # compute the gradients
        loss.backward()

        # step
        for optimizer in optimizers:
            optimizer.step()

        # log

        # signal received, relaunch experiment
        if os.environ['SIGNAL_RECEIVED'] == 'True':
            save_checkpoint(args, iter_epoch + 1, models, optimizers)
            if not args.rank:
                trigger_job_requeue(
                    os.path.join(args.dump_path, 'checkpoint.pth.tar'))

        # regular checkpoints
        if iter_epoch and iter_epoch % 1000 == 0:
            save_checkpoint(args, iter_epoch + 1, models, optimizers)

        # update stats
        log_loss.update(loss.item(), output.size(0))
        prec1 = accuracy(args, output, target, sc_output=super_class_output)
        log_top1.update(prec1.item(), output.size(0))

        log_loss_superclass.update(loss_superclass.item(), output.size(0))
        prec1 = accuracy(args, super_class_output, sc_target)
        log_top1_superclass.update(prec1.item(), output.size(0))

        log_loss_subclass.update(loss_subclass.item(), output.size(0))
        prec1 = accuracy(args, output, target)
        log_top1_subclass.update(prec1.item(), output.size(0))

        batch_time.update(time.perf_counter() - end)
        end = time.perf_counter()

        # verbose
        if iter_epoch % 100 == 0:
            logger.info(
                'Epoch[{0}] - Iter: [{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Prec {log_top1.val:.3f} ({log_top1.avg:.3f})\t'
                'Super-class loss: {sc_loss.val:.3f} ({sc_loss.avg:.3f})\t'
                'Super-class prec: {sc_prec.val:.3f} ({sc_prec.avg:.3f})\t'
                'Intra super-class loss: {los.val:.3f} ({los.avg:.3f})\t'
                'Intra super-class prec: {prec.val:.3f} ({prec.avg:.3f})\t'
                .format(args.epoch, iter_epoch, len(loader),
                        batch_time=batch_time, data_time=data_time,
                        loss=log_loss, log_top1=log_top1,
                        sc_loss=log_loss_superclass, sc_prec=log_top1_superclass,
                        los=log_loss_subclass, prec=log_top1_subclass))

    # end of epoch
    args.start_iter = 0
    args.epoch += 1

    # dump checkpoint
    save_checkpoint(args, 0, models, optimizers)
    if not args.rank:
        if not (args.epoch - 1) % args.checkpoint_freq:
            shutil.copyfile(
                os.path.join(args.dump_path, 'checkpoint.pth.tar'),
                os.path.join(args.dump_checkpoints,
                             'checkpoint' + str(args.epoch - 1) + '.pth.tar'),
            )

    return (
        args.epoch - 1,
        args.epoch * len(loader),
        log_top1.avg,
        log_loss.avg,
        log_top1_superclass.avg,
        log_loss_superclass.avg,
        log_top1_subclass.avg,
        log_loss_subclass.avg,
    )
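
# In the loop above, `sc_target` is a constant vector holding the local
# super-class id for every sample in the batch; it is built by zeroing an
# uninitialized cuda LongTensor and adding the id. The hypothetical helper
# below (not part of the original code) sketches an equivalent, more explicit
# construction with torch.full, assuming the batch really has
# `args.batch_size` samples.
def constant_superclass_target(world_id, batch_size, device='cuda'):
    # one identical super-class label per sample in the batch
    return torch.full((batch_size,), world_id, dtype=torch.long, device=device)
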
def train_network(args, model, optimizer, dataset):
    """
    Train the models on the dataset.
    """
    # switch to train mode
    model.train()

    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
    )

    # running statistics
    batch_time = AverageMeter()
    data_time = AverageMeter()

    # training statistics
    log_top1 = AverageMeter()
    log_loss = AverageMeter()

    end = time.perf_counter()

    cel = nn.CrossEntropyLoss().cuda()

    for iter_epoch, (inp, target) in enumerate(loader):
        # measure data loading time
        data_time.update(time.perf_counter() - end)

        # start at iter start_iter
        if iter_epoch < args.start_iter:
            continue

        # move to gpu
        inp = inp.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward
        output = model(inp)

        # compute cross entropy loss
        loss = cel(output, target)

        optimizer.zero_grad()

        # compute the gradients
        loss.backward()

        # step
        optimizer.step()

        # log

        # signal received, relaunch experiment
        if os.environ['SIGNAL_RECEIVED'] == 'True':
            if not args.rank:
                torch.save(
                    {
                        'epoch': args.epoch,
                        'start_iter': iter_epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    },
                    os.path.join(args.dump_path, 'checkpoint.pth.tar'))
                trigger_job_requeue(
                    os.path.join(args.dump_path, 'checkpoint.pth.tar'))

        # update stats
        log_loss.update(loss.item(), output.size(0))
        prec1 = accuracy(args, output, target)
        log_top1.update(prec1.item(), output.size(0))

        batch_time.update(time.perf_counter() - end)
        end = time.perf_counter()

        # verbose
        if iter_epoch % 100 == 0:
            logger.info(
                'Epoch[{0}] - Iter: [{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Prec {log_top1.val:.3f} ({log_top1.avg:.3f})\t'.format(
                    args.epoch, iter_epoch, len(loader),
                    batch_time=batch_time, data_time=data_time,
                    loss=log_loss, log_top1=log_top1))

    # end of epoch
    args.start_iter = 0
    args.epoch += 1

    # dump checkpoint
    if not args.rank:
        torch.save(
            {
                'epoch': args.epoch,
                'start_iter': 0,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            os.path.join(args.dump_path, 'checkpoint.pth.tar'))
        if not (args.epoch - 1) % args.checkpoint_freq:
            shutil.copyfile(
                os.path.join(args.dump_path, 'checkpoint.pth.tar'),
                os.path.join(args.dump_checkpoints,
                             'checkpoint' + str(args.epoch - 1) + '.pth.tar'),
            )

    return (args.epoch - 1, args.epoch * len(loader), log_top1.avg, log_loss.avg)
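
# The training loops above call an `accuracy(args, output, target, ...)` helper
# defined elsewhere in the project. The sketch below computes plain top-1
# precision (in %) for a batch, which matches how the logged values are used
# here; the project's actual helper may differ, for instance in how it combines
# super-class and sub-class outputs when `sc_output` is passed.
def top1_accuracy(output, target):
    """Return top-1 precision (in %) of `output` logits against `target`."""
    with torch.no_grad():
        pred = output.argmax(dim=1)
        correct = (pred == target).float().sum()
        return correct.mul_(100.0 / target.size(0))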