Ejemplo n.º 1
0
def build_model(args):
    """Create the online encoder and its EMA (key) counterpart.

    Two ``resnet50(width=args.model_width)`` networks are built, in order, and
    each is moved to the GPU with ``.cuda()``. ``moment_update`` is then called
    with a third argument of 0, which copies the weights of the first network
    into the second so both start out identical.

    Args:
        args: configuration object; only ``args.model_width`` is read here.

    Returns:
        A ``(model, model_ema)`` tuple.
    """
    width = args.model_width
    encoder, ema_encoder = (resnet50(width=width).cuda() for _ in range(2))

    # A factor of 0 means `ema_encoder` takes `encoder`'s weights verbatim.
    moment_update(encoder, ema_encoder, 0)
    return encoder, ema_encoder
Ejemplo n.º 2
0
def train_moco(epoch, train_loader, model, model_ema, contrast, criterion, optimizer, scheduler, args):
    """Run one epoch of MoCo training.

    For every batch, ``inputs`` is split along dim 1 into two 3-channel views.

    - The first view goes through ``model`` with gradients enabled.
    - The second view goes through ``model_ema`` under ``torch.no_grad()``,
      wrapped in ``DistributedShufle.forward_shuffle`` / ``backward_shuffle``.
    - ``contrast`` maps the query/key features to logits and ``criterion`` maps
      the logits to a loss.
    - The optimizer and the scheduler are each stepped once per iteration.
    - ``moment_update(model, model_ema, args.alpha)`` then updates the EMA
      encoder (presumably a momentum/EMA blend — confirm in ``moment_update``).

    Args:
        epoch: current epoch index. Used in log lines and passed to
            ``DistributedShufle.forward_shuffle``.
        train_loader: iterable of ``(inputs, _)`` batches. ``inputs`` must have
            size 6 along dim 1 (two 3-channel views concatenated).
        model: query encoder, trained via backprop.
        model_ema: key encoder. It is only run under ``torch.no_grad()`` and
            updated through ``moment_update``.
        contrast: callable ``(feat_q, feat_k, feat_k_all) -> logits``.
        criterion: callable ``logits -> scalar loss``.
        optimizer: optimizer stepped after each backward pass (presumably over
            ``model``'s parameters — confirm at the call site).
        scheduler: LR scheduler, stepped every iteration (not every epoch).
        args: must provide ``amp_opt_level``, ``alpha`` and ``print_freq``.

    Returns:
        ``(avg_loss, avg_prob)``: batch-size-weighted epoch averages.
        ``prob`` is the mean softmax probability assigned to logit column 0
        (presumably the positive key — confirm against ``contrast``).
    """
    model.train()
    set_bn_train(model_ema)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    prob_meter = AverageMeter()

    end = time.time()
    for idx, (inputs, _) in enumerate(train_loader):
        data_time.update(time.time() - end)

        bsz = inputs.size(0)

        # forward
        x1, x2 = torch.split(inputs, [3, 3], dim=1)
        # `.contiguous()` returns a tensor rather than modifying in place, so
        # its result must be assigned. It is applied after the device transfer
        # to avoid an extra host-side copy, and it is a no-op when the device
        # tensor is already contiguous.
        x1 = x1.cuda(non_blocking=True).contiguous()
        x2 = x2.cuda(non_blocking=True).contiguous()

        feat_q = model(x1)
        with torch.no_grad():
            x2_shuffled, backward_inds = DistributedShufle.forward_shuffle(x2, epoch)
            feat_k = model_ema(x2_shuffled)
            feat_k_all, feat_k = DistributedShufle.backward_shuffle(feat_k, backward_inds, return_local=True)

        out = contrast(feat_q, feat_k, feat_k_all)
        loss = criterion(out)
        # Monitoring-only statistic: detach so no autograd graph is built for it.
        prob = F.softmax(out.detach(), dim=1)[:, 0].mean()

        # backward
        optimizer.zero_grad()
        if args.amp_opt_level != "O0":
            # Any opt level other than "O0" backprops through `amp.scale_loss`.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()
        scheduler.step()

        moment_update(model, model_ema, args.alpha)

        # update meters
        loss_meter.update(loss.item(), bsz)
        prob_meter.update(prob.item(), bsz)
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if idx % args.print_freq == 0:
            logger.info(f'Train: [{epoch}][{idx}/{len(train_loader)}]\t'
                        f'T {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        f'loss {loss_meter.val:.3f} ({loss_meter.avg:.3f})\t'
                        f'prob {prob_meter.val:.3f} ({prob_meter.avg:.3f})')

    return loss_meter.avg, prob_meter.avg