Example 1
def test_grad_corr(dataloader, net, ssh, ext):
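    # Estimate how well the self-supervised rotation task is aligned with
    # the main classification task: for each batch, take the dot product of
    # the two losses' gradients w.r.t. the shared feature extractor `ext`.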
    criterion = nn.CrossEntropyLoss().cuda()
    net.eval()
    ssh.eval()
    corr = []
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        net.zero_grad()
        ssh.zero_grad()
        inputs_cls, labels_cls = inputs.cuda(), labels.cuda()
        outputs_cls = net(inputs_cls)
        loss_cls = criterion(outputs_cls, labels_cls)
        grad_cls = torch.autograd.grad(loss_cls, ext.parameters())
        grad_cls = flat_grad(grad_cls)

        ext.zero_grad()
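        # self-supervised branch: rotate the inputs and predict the rotation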
        inputs, labels = rotate_batch(inputs, 'expand')
        inputs_ssh, labels_ssh = inputs.cuda(), labels.cuda()
        outputs_ssh = ssh(inputs_ssh)
        loss_ssh = criterion(outputs_ssh, labels_ssh)
        grad_ssh = torch.autograd.grad(loss_ssh, ext.parameters())
        grad_ssh = flat_grad(grad_ssh)

        corr.append(torch.dot(grad_cls, grad_ssh).item())
    net.train()
    ssh.train()
    return corr
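The helper flat_grad is defined elsewhere in the project. A minimal sketch consistent with its use above (flattening the gradient tensors into one 1-D vector so torch.dot can compare the two tasks' gradients); this is an assumption, not the original implementation:

import torch

def flat_grad(grads):
    # Flatten each gradient tensor and concatenate into a single 1-D vector.
    return torch.cat([g.reshape(-1) for g in grads])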
Example 2
def adapt_single(model, image, optimizer, criterion, niter, batch_size):
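    # Test-time adaptation on a single image: build a batch of augmented
    # copies and minimize the rotation-prediction loss for niter steps.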
    model.train()
    for iteration in range(niter):
        inputs = [rotation_tr_transforms(image) for _ in range(batch_size)]
        inputs, labels = rotate_batch(inputs)
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
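rotate_batch is another project helper, used above both with a mode argument ('expand' in Example 1, sslabel in Example 4) and without one. A minimal sketch consistent with those call sites, assuming labels in {0, 1, 2, 3} encode multiples of 90 degrees; the mode names and the default are assumptions, not the original implementation:

import torch

def rotate_batch(batch, label='rand'):
    # Accepts a list of CHW tensors or an NCHW tensor.
    # 'rand'   -> one random rotation per image (assumed default)
    # 'expand' -> all four rotations of every image
    # int k    -> rotate every image by k * 90 degrees
    if isinstance(batch, (list, tuple)):
        batch = torch.stack(list(batch))
    if label == 'expand':
        images = torch.cat([torch.rot90(batch, k, dims=(2, 3)) for k in range(4)])
        labels = torch.arange(4).repeat_interleave(len(batch))
        return images, labels
    if label == 'rand':
        labels = torch.randint(4, (len(batch),))
    else:
        labels = torch.full((len(batch),), int(label))
    images = torch.stack([torch.rot90(img, int(k), dims=(1, 2))
                          for img, k in zip(batch, labels)])
    return images, labels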
Example 3
def adapt_single_tensor(model, tensor, optimizer, criterion, niter, batch_size):
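    # Same adaptation loop, but on an already-preprocessed tensor; the model
    # returns (classification output, self-supervised output) and only the
    # self-supervised output contributes to the loss.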
    model.train()
    for iteration in range(niter):
        inputs = [tensor for _ in range(batch_size)]
        inputs, labels = rotate_batch(inputs)
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        _, ssh = model(inputs)
        loss = criterion(ssh, labels)
        loss.backward()
        optimizer.step()
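A hypothetical call, assuming a joint model whose forward pass returns the pair above; every name here is illustrative, not from the original project:

import torch.nn as nn
import torch.optim as optim

criterion = nn.CrossEntropyLoss()                    # hypothetical setup
optimizer = optim.SGD(model.parameters(), lr=1e-3)   # illustrative hyperparameters
adapt_single_tensor(model, test_image, optimizer, criterion,
                    niter=1, batch_size=32)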
Example 4
def test(dataloader, model, sslabel=None):
    criterion = nn.CrossEntropyLoss(reduction='none').cuda()
    model.eval()  # put the model in evaluation (test) mode
    correct = []
    losses = []
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        if sslabel is not None:
            inputs, labels = rotate_batch(inputs, sslabel)  # apply the rotation here
        inputs, labels = inputs.cuda(), labels.cuda()
        with torch.no_grad():
            outputs = model(inputs)  # check that the inputs can be fed through the model
            loss = criterion(outputs, labels)  # recognition task: is the correct label predicted
            losses.append(loss.cpu())
            _, predicted = outputs.max(1)
            correct.append(predicted.eq(labels).cpu())  # whether each prediction is correct
    correct = torch.cat(correct).numpy()  # note: verify this later
    losses = torch.cat(losses).numpy()
    model.train()  # switch back to training mode
    return 1 - correct.mean(), correct, losses
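Since the criterion uses reduction='none', the loss is kept per example, so the function returns the top-1 error rate together with per-sample correctness and loss arrays. A hypothetical usage (teloader and net are illustrative names):

err, correct_mask, per_sample_losses = test(teloader, net)
print('top-1 error: {:.4f}'.format(err))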
Example 5
def train(trloader, epoch):
    net.train()
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(trloader),
                             batch_time,
                             data_time,
                             losses,
                             top1,
                             prefix="Epoch: [{}]".format(epoch))

    end = time.time()
    for i, dl in enumerate(trloader):
        data_time.update(time.time() - end)
        optimizer.zero_grad()

        inputs_cls, labels_cls = dl[0].to(device), dl[1].to(device)
        outputs_cls, _ = net(inputs_cls)
        loss = criterion(outputs_cls, labels_cls)
        losses.update(loss.item(), len(labels_cls))

        _, predicted = outputs_cls.max(1)
        acc1 = predicted.eq(labels_cls).sum().item() / len(labels_cls)
        top1.update(acc1, len(labels_cls))

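        # auxiliary self-supervised branch: predict the rotation applied to the inputs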
        rot_inputs, rot_labels = rotate_batch(dl[0])
        inputs_ssh, labels_ssh = rot_inputs.to(device), rot_labels.to(device)
        _, outputs_ssh = net(inputs_ssh)
        loss_ssh = criterion(outputs_ssh, labels_ssh)
        loss += loss_ssh

        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            progress.print(i)
Example 6
all_err_cls = []
all_err_ssh = []
print('Running...')
print('Error (%)\t\ttest\t\tself-supervised')
for epoch in range(1, args.nepoch + 1):
    net.train()
    ssh.train()

    for batch_idx, (inputs, labels) in enumerate(trloader):
        optimizer.zero_grad()
        inputs_cls, labels_cls = inputs.cuda(), labels.cuda()
        outputs_cls = net(inputs_cls)
        loss = criterion(outputs_cls, labels_cls)

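        # optional auxiliary rotation task (enabled when args.shared is set)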
        if args.shared is not None:
            inputs_ssh, labels_ssh = rotate_batch(
                inputs, args.rotation_type)  # rotations sampled per args.rotation_type
            inputs_ssh, labels_ssh = inputs_ssh.cuda(), labels_ssh.cuda()
            outputs_ssh = ssh(inputs_ssh)
            loss_ssh = criterion(outputs_ssh, labels_ssh)
            loss += loss_ssh

        loss.backward()
        optimizer.step()

    err_cls = test(teloader, net)[0]
    err_ssh = 0 if args.shared is None else test(
        teloader, ssh, sslabel='expand')[0]
    all_err_cls.append(err_cls)
    all_err_ssh.append(err_ssh)
    scheduler.step()
Example 7
    if epoch < args.epochs_pre:
        net.train()
        ssh.train()
        c_err_cls = []
        avg_err_cls = 0
        for i, (inputs, labels) in enumerate(train_loader):
            # print("inside train")
            optimizer.zero_grad()
            X, label = inputs.to(device), labels.to(device)
            y = net(X)
            loss = criterion(y, label)

            if args.shared is not None:
                inputs_ssh, labels_ssh = rotate_batch(inputs, args.rotation_type)
                inputs_ssh, labels_ssh = inputs_ssh.to(device), labels_ssh.to(device)
                outputs_ssh = ssh(inputs_ssh)
                loss_ssh = criterion(outputs_ssh, labels_ssh)
                loss += loss_ssh

            loss.backward()
            optimizer.step()
        scheduler.step()

    else:
        if epoch % args.aug_freq == 0 or epoch == args.epochs_pre:
            AUG_STEP += 1
            net.eval()
Example 8
trset, trloader = prepare_train_data(args)
teset, teloader = prepare_test_data(args)

corrs = []
print("Gradient Correlation")
for i in range(args.epochs):
    idx = random.randint(0, len(trset) - 1)
    img, lbl = trset[idx]
    random_rot = random.randint(1, 3)
    rot_img = rotate_single_with_label(img, random_rot)

    # collect the gradients of the auxiliary (rotation) head's loss
    d_aux_loss = []
    inputs = [rot_img for _ in range(args.batch_size)]
    inputs, labels = rotate_batch(inputs)
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    _, ssh = net(inputs)
    loss = criterion(ssh, labels)
    loss.backward(retain_graph=True)

    for p in net.parameters():
        if p.grad is None:
            continue
        # split point: a first dimension of 512 is assumed to mark the end
        # of the shared extractor, so stop collecting gradients there
        if p.grad.size(0) == 512:
            break
        d_aux_loss.append(p.grad.data.clone())

    # collect the gradients of the main head's loss
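The excerpt breaks off here. For reference, a minimal sketch of the rotate_single_with_label helper used above, assuming the label encodes a multiple of 90 degrees (an assumption, not the original implementation):

import torch

def rotate_single_with_label(img, label):
    # Rotate a single CHW image tensor by label * 90 degrees.
    return torch.rot90(img, int(label), dims=(1, 2))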