def build_model(args):
    """Create the ResNet18 encoder and its EMA counterpart, both on the GPU.

    Args:
        args: Run configuration. Not read by this variant; kept so the
            signature matches the other model builders.

    Returns:
        Tuple ``(model, model_ema)`` of two ResNet18 instances moved to CUDA.

    NOTE(review): another ``build_model`` appears later in this view; if both
    live in the same module the later definition wins -- confirm.
    """
    # Instantiate in the same order as before: `model` first, then `model_ema`.
    online_net, ema_net = (ResNet18().cuda() for _ in range(2))

    # A momentum of 0 makes `moment_update` copy the weights of `model`
    # into `model_ema`, so both networks start out identical.
    moment_update(online_net, ema_net, 0)

    return online_net, ema_net
def build_model(args):
    """Create the resnet50 encoder and its EMA counterpart, both on the GPU.

    Args:
        args: Run configuration; ``args.model_width`` is forwarded to
            ``resnet50`` as its ``width`` argument.

    Returns:
        Tuple ``(model, model_ema)`` of two resnet50 instances moved to CUDA.
    """
    # Instantiate in the same order as before: `model` first, then `model_ema`.
    online_net, ema_net = (
        resnet50(width=args.model_width).cuda() for _ in range(2)
    )

    # A momentum of 0 makes `moment_update` copy the weights of `model`
    # into `model_ema`, so both networks start out identical.
    moment_update(online_net, ema_net, 0)

    return online_net, ema_net
def train_moco(epoch, train_loader, model, model_ema, contrast, criterion, optimizer, scheduler, args):
    """Run one epoch of MoCo training.

    For every batch the query view goes through ``model`` and the key view
    through ``model_ema`` (without gradients); ``contrast`` turns both feature
    sets into logits, ``criterion`` turns the logits into the loss, and after
    the optimizer step ``model_ema`` is updated from ``model`` via
    ``moment_update`` with momentum ``args.alpha``.

    Args:
        epoch: Epoch index; only used in the progress message.
        train_loader: Iterable yielding ``((x1, x2), _)`` batches, i.e. two
            views per sample plus an ignored second element.
        model: Network being optimized (produces the query features).
        model_ema: Momentum network (produces the key features).
        contrast: Callable invoked as ``contrast(feat_q, feat_k, feat_k)``
            that returns the logits.
        criterion: Callable mapping the logits to a scalar loss.
        optimizer: Optimizer for ``model``.
        scheduler: LR scheduler; stepped once per iteration, not per epoch.
        args: Configuration providing ``alpha``, ``local_rank`` and
            ``print_freq``.

    Returns:
        Tuple ``(loss_meter.avg, prob_meter.avg)`` accumulated over the epoch.
    """
    model.train()
    set_bn_train(model_ema)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    prob_meter = AverageMeter()

    end = time.time()
    for idx, ((x1, x2), _) in enumerate(train_loader):
        # Time spent waiting for the loader to deliver this batch.
        data_time.update(time.time() - end)

        bsz = x1.size(0)

        # forward
        # `contiguous()` returns a tensor instead of working in place, so its
        # result has to be kept; the original bare calls had no effect.
        x1 = x1.contiguous().cuda(non_blocking=True)
        x2 = x2.contiguous().cuda(non_blocking=True)

        feat_q = model(x1)
        with torch.no_grad():
            feat_k = model_ema(x2)

        out = contrast(feat_q, feat_k, feat_k)
        loss = criterion(out)

        # Monitoring statistic only: mean softmax mass on column 0 of the
        # logits (presumably the positive pair -- confirm against `contrast`).
        # It is never backpropagated, so skip autograd tracking for it.
        with torch.no_grad():
            prob = F.softmax(out, dim=1)[:, 0].mean()

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Refresh the momentum network from the freshly updated `model`.
        moment_update(model, model_ema, args.alpha)

        # update meters
        loss_meter.update(loss.item(), bsz)
        prob_meter.update(prob.item(), bsz)
        batch_time.update(time.time() - end)
        end = time.time()

        # print info (only from the process with local rank 0)
        if args.local_rank == 0 and idx % args.print_freq == 0:
            print(f'Train: [{epoch}][{idx}/{len(train_loader)}]\t'
                  f'T {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  f'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  f'loss {loss_meter.val:.3f} ({loss_meter.avg:.3f})\t'
                  f'prob {prob_meter.val:.3f} ({prob_meter.avg:.3f})')

    return loss_meter.avg, prob_meter.avg