def build_model(args):
    model = resnet50(width=args.model_width).cuda()
    model_ema = resnet50(width=args.model_width).cuda()

    # copy weights from `model' to `model_ema'
    moment_update(model, model_ema, 0)

    return model, model_ema
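
# `moment_update` is defined elsewhere in the repo. A minimal sketch of the
# usual momentum (EMA) update is given below, assuming this signature: with
# m = 0 it reduces to a plain weight copy (as used in `build_model` above),
# while m close to 1 keeps `model_ema` a slowly moving average of `model`.
def moment_update(model, model_ema, m):
    """model_ema = m * model_ema + (1 - m) * model (in-place, no gradients)."""
    with torch.no_grad():
        for p, p_ema in zip(model.parameters(), model_ema.parameters()):
            p_ema.data.mul_(m).add_(p.data, alpha=1 - m)
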
def train_moco(epoch, train_loader, model, model_ema, contrast, criterion, optimizer, scheduler, args):
    """ one epoch training for moco """
    model.train()
    set_bn_train(model_ema)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    prob_meter = AverageMeter()

    end = time.time()
    for idx, (inputs, _) in enumerate(train_loader):
        data_time.update(time.time() - end)
        bsz = inputs.size(0)

        # forward: the loader stacks two augmented views along the channel dim
        x1, x2 = torch.split(inputs, [3, 3], dim=1)
        # `contiguous()` is not in-place, so its result must be reassigned
        x1 = x1.contiguous()
        x2 = x2.contiguous()
        x1 = x1.cuda(non_blocking=True)
        x2 = x2.cuda(non_blocking=True)

        # query features from the online encoder
        feat_q = model(x1)
        # key features from the momentum encoder, with BN shuffling across GPUs
        with torch.no_grad():
            x2_shuffled, backward_inds = DistributedShufle.forward_shuffle(x2, epoch)
            feat_k = model_ema(x2_shuffled)
            feat_k_all, feat_k = DistributedShufle.backward_shuffle(feat_k, backward_inds, return_local=True)

        out = contrast(feat_q, feat_k, feat_k_all)

        loss = criterion(out)
        prob = F.softmax(out, dim=1)[:, 0].mean()

        # backward
        optimizer.zero_grad()
        if args.amp_opt_level != "O0":
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()
        scheduler.step()

        # momentum update of the key encoder
        moment_update(model, model_ema, args.alpha)

        # update meters
        loss_meter.update(loss.item(), bsz)
        prob_meter.update(prob.item(), bsz)
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if idx % args.print_freq == 0:
            logger.info(f'Train: [{epoch}][{idx}/{len(train_loader)}]\t'
                        f'T {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        f'loss {loss_meter.val:.3f} ({loss_meter.avg:.3f})\t'
                        f'prob {prob_meter.val:.3f} ({prob_meter.avg:.3f})')

    return loss_meter.avg, prob_meter.avg
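
# The two helpers below are defined elsewhere in the repo; minimal sketches are
# given here so the excerpt is self-contained. Their bodies are assumptions,
# not the repo's actual code.

# `set_bn_train` presumably keeps the momentum encoder in eval mode overall
# while putting its BatchNorm layers in train mode, so the EMA encoder's BN
# running statistics keep updating even though it never receives gradients.
def set_bn_train(model):
    def set_bn_train_helper(m):
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            m.train()

    model.eval()
    model.apply(set_bn_train_helper)


# `DistributedShufle` implements shuffling BN (MoCo, He et al. 2020): the key
# batch is gathered from all GPUs, permuted with a permutation shared by every
# rank, split back across GPUs for the momentum-encoder forward, and unshuffled
# afterwards, so BN cannot leak intra-GPU batch statistics between query and
# key views. The sketch below is one plausible implementation; it assumes a
# standard `torch.distributed` process group has been initialized.
import torch.distributed as dist


def dist_collect(x):
    """Gather same-sized tensors from all ranks and concatenate along dim 0."""
    out_list = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(out_list, x)
    return torch.cat(out_list, dim=0)


class DistributedShufle:
    @staticmethod
    def forward_shuffle(x, epoch):
        """Shuffle the global batch and return this rank's shard of it,
        plus the indices needed to undo the shuffle later."""
        x_all = dist_collect(x)
        forward_inds, backward_inds = DistributedShufle.get_shuffle_ids(x_all.shape[0], epoch)
        forward_inds_local = DistributedShufle.get_local_id(forward_inds)
        return x_all[forward_inds_local], backward_inds

    @staticmethod
    def backward_shuffle(x, backward_inds, return_local=True):
        """Gather the shuffled keys from all ranks and restore the original
        order; optionally also return this rank's own unshuffled slice."""
        x_all = dist_collect(x)
        if return_local:
            backward_inds_local = DistributedShufle.get_local_id(backward_inds)
            return x_all[backward_inds], x_all[backward_inds_local]
        return x_all[backward_inds]

    @staticmethod
    def get_local_id(ids):
        # each rank owns an equal, contiguous chunk of the global index list
        return ids.chunk(dist.get_world_size())[dist.get_rank()]

    @staticmethod
    def get_shuffle_ids(bsz, epoch):
        # seeding with `epoch` is the simplest way to have every rank draw the
        # same permutation without extra communication (the essential
        # requirement); broadcasting a permutation from rank 0 also works
        torch.manual_seed(epoch)
        forward_inds = torch.randperm(bsz).long().cuda()
        # invert the permutation: backward_inds[forward_inds[i]] = i
        backward_inds = torch.zeros(bsz).long().cuda()
        backward_inds.index_copy_(0, forward_inds, torch.arange(bsz).long().cuda())
        return forward_inds, backward_inds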