Пример #1
0
def main(epochs, cpu, cudnn_flag, visdom_port, visdom_freq, temp_dir, seed,
         no_bias_decay, label_smoothing, temperature):
    """Train a model, validate after every epoch and save the best weights.

    The model is evaluated once before training (logged as step 0) and once
    after each epoch. The weights with the highest cosine Recall@1 seen so far
    are kept (ties go to the most recent epoch) and written to ``temp_dir``
    at the end of the run.

    Relies on the module-level experiment object ``ex`` for scalar logging,
    run configuration and artifact registration.

    Args:
        epochs: Number of training epochs.
        cpu: If truthy, run on the CPU even when CUDA is available.
        cudnn_flag: ``'deterministic'`` enables ``cudnn.deterministic`` for the
            whole run; ``'benchmark'`` enables ``cudnn.benchmark`` during
            training only and disables it for validation. Any other value
            leaves the cudnn settings untouched.
        visdom_port: Port for the Visdom logger; a falsy value disables it.
        visdom_freq: Forwarded to ``train`` as ``freq``.
        temp_dir: Directory in which the checkpoint (and the Visdom dump, if
            any) is written before being registered as an artifact.
        seed: Seed passed to ``torch.manual_seed`` before building the
            loaders, before building the model and before training.
        no_bias_decay: If truthy, 1-D parameters are put in a separate
            optimizer group with ``weight_decay`` set to 0.
        label_smoothing: ``epsilon`` of the ``SmoothCrossEntropy`` loss.
        temperature: ``temperature`` of the ``SmoothCrossEntropy`` loss.

    Returns:
        The best cosine Recall@1 observed on the validation data.
    """
    use_cuda = torch.cuda.is_available() and not cpu
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    callback = VisdomLogger(port=visdom_port) if visdom_port else None
    if cudnn_flag == 'deterministic':
        cudnn.deterministic = True

    torch.manual_seed(seed)
    loaders, recall_ks = get_loaders()

    # Re-seed so model initialisation does not depend on how much randomness
    # the loader construction consumed.
    torch.manual_seed(seed)
    model = get_model(num_classes=loaders.num_classes)
    class_loss = SmoothCrossEntropy(epsilon=label_smoothing,
                                    temperature=temperature)

    model.to(device)
    # Only parallelise when the model actually lives on a GPU: DataParallel
    # requires the parameters to be on its first device and fails at the first
    # forward pass when a CPU run is requested on a multi-GPU machine.
    if device.type == 'cuda' and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    parameters = []
    if no_bias_decay:
        # Multi-dimensional parameters keep the optimizer's default decay ...
        parameters.append(
            {'params': [par for par in model.parameters() if par.dim() != 1]})
        # ... while 1-D parameters are exempt from it.
        parameters.append({
            'params': [par for par in model.parameters() if par.dim() == 1],
            'weight_decay': 0,
        })
    else:
        parameters.append({'params': model.parameters()})
    optimizer, scheduler = get_optimizer_scheduler(
        parameters=parameters, loader_length=len(loaders.train))

    # setup partial function to simplify call
    eval_function = partial(evaluate,
                            model=model,
                            recall=recall_ks,
                            query_loader=loaders.query,
                            gallery_loader=loaders.gallery)

    # Baseline evaluation before any training; it seeds the best-so-far state.
    metrics = eval_function()
    if callback is not None:
        # NOTE(review): this title differs from the per-epoch one below
        # ('Val Recall'), so the step-0 point may be plotted separately from
        # the later ones -- confirm which title is intended.
        callback.scalars(
            ['l2', 'cosine'],
            0, [metrics.recall['l2'][1], metrics.recall['cosine'][1]],
            title='Val Recall@1')
    pprint(metrics.recall)
    best_recall = metrics.recall
    best_state = deepcopy(model.state_dict())

    torch.manual_seed(seed)
    for epoch in range(epochs):
        if cudnn_flag == 'benchmark':
            cudnn.benchmark = True

        train(model=model,
              loader=loaders.train,
              class_loss=class_loss,
              optimizer=optimizer,
              scheduler=scheduler,
              epoch=epoch,
              callback=callback,
              freq=visdom_freq,
              ex=ex)

        # validation
        if cudnn_flag == 'benchmark':
            cudnn.benchmark = False
        metrics = eval_function()
        print('Validation [{:03d}]'.format(epoch))
        pprint(metrics.recall)
        ex.log_scalar('val.recall_l2@1',
                      metrics.recall['l2'][1],
                      step=epoch + 1)
        ex.log_scalar('val.recall_cosine@1',
                      metrics.recall['cosine'][1],
                      step=epoch + 1)

        if callback is not None:
            callback.scalars(
                ['l2', 'cosine'],
                epoch + 1,
                [metrics.recall['l2'][1], metrics.recall['cosine'][1]],
                title='Val Recall')

        # Keep a snapshot of the weights whenever cosine Recall@1 is at least
        # as good as the best seen so far.
        if metrics.recall['cosine'][1] >= best_recall['cosine'][1]:
            best_recall = metrics.recall
            best_state = deepcopy(model.state_dict())

    # logging
    ex.info['recall'] = best_recall

    # saving
    save_name = os.path.join(
        temp_dir, '{}_{}.pt'.format(ex.current_run.config['model']['arch'],
                                    ex.current_run.config['dataset']['name']))
    torch.save(state_dict_to_cpu(best_state), save_name)
    ex.add_artifact(save_name)

    if callback is not None:
        save_name = os.path.join(temp_dir, 'visdom_data.pt')
        callback.save(save_name)
        ex.add_artifact(save_name)

    return best_recall['cosine'][1]
def do_epoch(args: argparse.Namespace,
             train_loader: torch.utils.data.DataLoader, model: DDP,
             optimizer: torch.optim.Optimizer,
             scheduler: torch.optim.lr_scheduler, epoch: int,
             callback: VisdomLogger, iter_per_epoch: int,
             log_iter: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run ``iter_per_epoch`` training steps and collect periodic metrics.

    Every ``args.log_freq`` iterations the current batch is re-evaluated in
    eval mode to compute accuracy / IoU statistics, which are reduced across
    processes when ``args.distributed`` is set and, on the main process,
    sent to ``callback`` and stored in the returned tensors.

    Args:
        args: Run configuration; reads ``num_classes_tr``, ``scheduler``,
            ``log_freq`` and ``distributed`` (and is passed to
            ``main_process`` / ``compute_loss``).
        train_loader: Loader yielding ``(images, gt)`` batches.
        model: Network being trained.
        optimizer: Optimizer stepped once per iteration.
        scheduler: LR scheduler; stepped every iteration when
            ``args.scheduler == 'cosine'``, otherwise once at the end of the
            epoch.
        epoch: Zero-based epoch index, used to compute the global iteration.
        callback: Optional Visdom logger (may be ``None``).
        iter_per_epoch: Number of batches to draw from ``train_loader``.
        log_iter: Length of the returned metric tensors.

    Returns:
        ``(train_mIous, train_losses)``: two tensors of length ``log_iter``
        placed on the device indexed by ``dist.get_rank()``. They are only
        filled in on the main process; slots that were never logged stay 0.
    """
    loss_meter = AverageMeter()
    # The process rank is used as the device index for every tensor here.
    device = dist.get_rank()
    train_losses = torch.zeros(log_iter).to(device)
    train_mIous = torch.zeros(log_iter).to(device)

    iterable_train_loader = iter(train_loader)
    loader_length = len(train_loader)

    # Only the main process displays a progress bar.
    if main_process(args):
        bar = tqdm(range(iter_per_epoch))
    else:
        bar = range(iter_per_epoch)

    for i in bar:
        # The logging branch below switches to eval mode, so reset each step.
        model.train()
        current_iter = epoch * loader_length + i + 1

        # Use the builtin next(): DataLoader iterators in recent PyTorch
        # releases no longer provide a Python-2 style ``.next()`` method.
        images, gt = next(iterable_train_loader)
        images = images.to(device, non_blocking=True)
        gt = gt.to(device, non_blocking=True)

        loss = compute_loss(
            args=args,
            model=model,
            images=images,
            targets=gt.long(),
            num_classes=args.num_classes_tr,
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if args.scheduler == 'cosine':
            scheduler.step()

        if i % args.log_freq == 0:
            model.eval()
            # Metrics only: no autograd graph is needed for this forward pass.
            with torch.no_grad():
                logits = model(images)
            # NOTE(review): 255 is presumably the ignore label -- confirm
            # against intersectionAndUnionGPU.
            intersection, union, target = intersectionAndUnionGPU(
                logits.argmax(1), gt, args.num_classes_tr, 255)
            if args.distributed:
                dist.all_reduce(loss)
                dist.all_reduce(intersection)
                dist.all_reduce(union)
                dist.all_reduce(target)

            allAcc = (intersection.sum() / (target.sum() + 1e-10))  # scalar
            mAcc = (intersection / (target + 1e-10)).mean()
            mIoU = (intersection / (union + 1e-10)).mean()
            loss_meter.update(loss.item() / dist.get_world_size())

            if main_process(args):
                if callback is not None:
                    # Fractional epoch used as the x-coordinate of the plots.
                    t = current_iter / loader_length
                    callback.scalar('loss_train_batch',
                                    t,
                                    loss_meter.avg,
                                    title='Loss')
                    callback.scalars(['mIoU', 'mAcc', 'allAcc'],
                                     t, [mIoU, mAcc, allAcc],
                                     title='Training metrics')
                    # Only the first parameter group's learning rate is shown.
                    lr = optimizer.param_groups[0]['lr']
                    callback.scalar('lr', t, lr, title='Learning rate')

                slot = int(i / args.log_freq)
                train_losses[slot] = loss_meter.avg
                train_mIous[slot] = mIoU

    # Non-cosine schedulers advance once per epoch rather than per iteration.
    if args.scheduler != 'cosine':
        scheduler.step()

    return train_mIous, train_losses