import argparse
import os
from copy import deepcopy
from functools import partial
from pprint import pprint
from typing import Tuple

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm

# Project-specific helpers -- `ex` (a Sacred Experiment), `VisdomLogger`,
# `get_loaders`, `get_model`, `get_optimizer_scheduler`, `evaluate`, `train`,
# `compute_loss` and `main_process` -- are assumed to be defined or imported
# elsewhere in the repository.


def main(epochs, cpu, cudnn_flag, visdom_port, visdom_freq, temp_dir, seed,
         no_bias_decay, label_smoothing, temperature):
    device = torch.device('cuda:0' if torch.cuda.is_available() and not cpu else 'cpu')
    callback = VisdomLogger(port=visdom_port) if visdom_port else None
    if cudnn_flag == 'deterministic':
        setattr(cudnn, cudnn_flag, True)

    torch.manual_seed(seed)
    loaders, recall_ks = get_loaders()

    torch.manual_seed(seed)
    model = get_model(num_classes=loaders.num_classes)
    class_loss = SmoothCrossEntropy(epsilon=label_smoothing, temperature=temperature)
    model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # Optionally exclude 1-d parameters (biases, norm scales) from weight decay.
    parameters = []
    if no_bias_decay:
        parameters.append(
            {'params': [par for par in model.parameters() if par.dim() != 1]})
        parameters.append({
            'params': [par for par in model.parameters() if par.dim() == 1],
            'weight_decay': 0
        })
    else:
        parameters.append({'params': model.parameters()})

    optimizer, scheduler = get_optimizer_scheduler(parameters=parameters,
                                                   loader_length=len(loaders.train))

    # setup partial function to simplify call
    eval_function = partial(evaluate, model=model, recall=recall_ks,
                            query_loader=loaders.query,
                            gallery_loader=loaders.gallery)

    # setup best validation logger: evaluate once before training
    metrics = eval_function()
    if callback is not None:
        callback.scalars(
            ['l2', 'cosine'], 0,
            [metrics.recall['l2'][1], metrics.recall['cosine'][1]],
            title='Val Recall@1')
    pprint(metrics.recall)
    best_val = (0, metrics.recall, deepcopy(model.state_dict()))

    torch.manual_seed(seed)
    for epoch in range(epochs):
        if cudnn_flag == 'benchmark':
            setattr(cudnn, cudnn_flag, True)

        train(model=model, loader=loaders.train, class_loss=class_loss,
              optimizer=optimizer, scheduler=scheduler, epoch=epoch,
              callback=callback, freq=visdom_freq, ex=ex)

        # validation (disable cudnn benchmarking during evaluation)
        if cudnn_flag == 'benchmark':
            setattr(cudnn, cudnn_flag, False)
        metrics = eval_function()
        print('Validation [{:03d}]'.format(epoch))
        pprint(metrics.recall)

        ex.log_scalar('val.recall_l2@1', metrics.recall['l2'][1], step=epoch + 1)
        ex.log_scalar('val.recall_cosine@1', metrics.recall['cosine'][1], step=epoch + 1)
        if callback is not None:
            callback.scalars(
                ['l2', 'cosine'], epoch + 1,
                [metrics.recall['l2'][1], metrics.recall['cosine'][1]],
                title='Val Recall')

        # save model dict if the chosen validation metric is better
        if metrics.recall['cosine'][1] >= best_val[1]['cosine'][1]:
            best_val = (epoch + 1, metrics.recall, deepcopy(model.state_dict()))

    # logging
    ex.info['recall'] = best_val[1]

    # saving
    save_name = os.path.join(
        temp_dir,
        '{}_{}.pt'.format(ex.current_run.config['model']['arch'],
                          ex.current_run.config['dataset']['name']))
    torch.save(state_dict_to_cpu(best_val[2]), save_name)
    ex.add_artifact(save_name)

    if callback is not None:
        save_name = os.path.join(temp_dir, 'visdom_data.pt')
        callback.save(save_name)
        ex.add_artifact(save_name)

    return best_val[1]['cosine'][1]
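# --- Hedged sketches of helpers used above (not shown in this file) ---

# `SmoothCrossEntropy` is assumed to combine temperature scaling of the logits
# with uniform label smoothing; this is a minimal sketch under that assumption,
# and the repository's actual implementation may differ.
class SmoothCrossEntropy(nn.Module):
    def __init__(self, epsilon: float = 0.0, temperature: float = 1.0):
        super().__init__()
        self.epsilon = epsilon
        self.temperature = temperature

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Temperature-scale the logits before the (log-)softmax.
        log_probs = F.log_softmax(logits / self.temperature, dim=1)
        n_classes = logits.size(1)
        with torch.no_grad():
            # Smoothed targets: epsilon spread over the wrong classes,
            # 1 - epsilon on the ground-truth class.
            smooth = torch.full_like(log_probs, self.epsilon / (n_classes - 1))
            smooth.scatter_(1, targets.unsqueeze(1), 1.0 - self.epsilon)
        return -(smooth * log_probs).sum(dim=1).mean()


# `state_dict_to_cpu` is assumed to simply move every tensor of the saved
# state dict to CPU so the checkpoint can be loaded without a GPU.
def state_dict_to_cpu(state_dict):
    return {k: v.cpu() for k, v in state_dict.items()}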
def do_epoch(args: argparse.Namespace,
             train_loader: torch.utils.data.DataLoader,
             model: DDP,
             optimizer: torch.optim.Optimizer,
             scheduler: torch.optim.lr_scheduler._LRScheduler,
             epoch: int,
             callback: VisdomLogger,
             iter_per_epoch: int,
             log_iter: int) -> Tuple[torch.Tensor, torch.Tensor]:
    loss_meter = AverageMeter()
    train_losses = torch.zeros(log_iter).to(dist.get_rank())
    train_mIous = torch.zeros(log_iter).to(dist.get_rank())

    iterable_train_loader = iter(train_loader)

    # Only the main process shows a progress bar.
    if main_process(args):
        bar = tqdm(range(iter_per_epoch))
    else:
        bar = range(iter_per_epoch)

    for i in bar:
        model.train()
        current_iter = epoch * len(train_loader) + i + 1

        images, gt = next(iterable_train_loader)
        images = images.to(dist.get_rank(), non_blocking=True)
        gt = gt.to(dist.get_rank(), non_blocking=True)

        loss = compute_loss(
            args=args,
            model=model,
            images=images,
            targets=gt.long(),
            num_classes=args.num_classes_tr,
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if args.scheduler == 'cosine':
            scheduler.step()

        if i % args.log_freq == 0:
            model.eval()
            with torch.no_grad():
                logits = model(images)
            intersection, union, target = intersectionAndUnionGPU(
                logits.argmax(1), gt, args.num_classes_tr, 255)
            if args.distributed:
                dist.all_reduce(loss)
                dist.all_reduce(intersection)
                dist.all_reduce(union)
                dist.all_reduce(target)

            allAcc = (intersection.sum() / (target.sum() + 1e-10))  # scalar
            mAcc = (intersection / (target + 1e-10)).mean()
            mIoU = (intersection / (union + 1e-10)).mean()
            loss_meter.update(loss.item() / dist.get_world_size())

            if main_process(args):
                if callback is not None:
                    t = current_iter / len(train_loader)  # fractional epoch
                    callback.scalar('loss_train_batch', t, loss_meter.avg,
                                    title='Loss')
                    callback.scalars(['mIoU', 'mAcc', 'allAcc'], t,
                                     [mIoU, mAcc, allAcc],
                                     title='Training metrics')
                    # Log the learning rate of the first param group only.
                    for index, param_group in enumerate(optimizer.param_groups):
                        lr = param_group['lr']
                        callback.scalar('lr', t, lr, title='Learning rate')
                        break

                train_losses[i // args.log_freq] = loss_meter.avg
                train_mIous[i // args.log_freq] = mIoU

    if args.scheduler != 'cosine':
        scheduler.step()

    return train_mIous, train_losses
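# --- Hedged sketches of helpers used above (not shown in this file) ---

# `AverageMeter` is assumed to be the usual running-average tracker.
class AverageMeter:
    def __init__(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val: float, n: int = 1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# `intersectionAndUnionGPU` is assumed to follow the common histogram-based
# formulation for segmentation metrics: flatten prediction and ground truth,
# drop `ignore_index` pixels, and return per-class intersection, union and
# target pixel counts on the GPU (ready for `dist.all_reduce`). A sketch under
# that assumption; the repository's own version may differ in details.
def intersectionAndUnionGPU(output: torch.Tensor, target: torch.Tensor,
                            num_classes: int, ignore_index: int = 255):
    output = output.view(-1)
    target = target.view(-1)
    # Mark ignored pixels in the prediction so they fall outside the
    # histogram range [0, num_classes - 1] and are not counted.
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    # Per-class pixel counts for correct predictions, predictions, ground truth.
    area_intersection = torch.histc(intersection.float(), bins=num_classes,
                                    min=0, max=num_classes - 1)
    area_output = torch.histc(output.float(), bins=num_classes,
                              min=0, max=num_classes - 1)
    area_target = torch.histc(target.float(), bins=num_classes,
                              min=0, max=num_classes - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target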