Example #1
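S2M2-style pretraining loop: a manifold-mixup classification pass plus a self-supervised 4-way rotation task per batch, with episodic few-shot validation every epoch. The import header below is a sketch of what the snippet needs; `use_gpu`, `model_dict`, `image_size`, `image_size_dct`, and `BaselineFinetune` are globals assumed to come from the surrounding repo, and the scrape's deprecated `Variable` wrappers are dropped since they are no-ops in PyTorch >= 0.4.

import os
from os import path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim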
def train_s2m2(base_loader, base_loader_test, val_loader, model, start_epoch,
               stop_epoch, params, tmp):

    if params.dct_status:
        channels = params.channels
    else:
        channels = 3

    val_acc_best = 0.0

    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    # Cache the validation episodes once so every epoch evaluates the same tasks.
    val_file = os.path.join(params.checkpoint_dir,
                            'val_' + params.dataset + '.pt')
    if path.exists(val_file):
        loader = torch.load(val_file)
    else:
        loader = [x for x, _ in val_loader]
        torch.save(loader, val_file)

    def mixup_criterion(criterion, pred, y_a, y_b, lam):
        # Mixup loss: a lam-weighted blend of the loss against both targets.
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

    criterion = nn.CrossEntropyLoss()

    # 4-way rotation head sized to the backbone's feature dimension.
    if params.model == 'WideResNet28_10':
        rotate_classifier = nn.Sequential(nn.Linear(640, 4))
    elif params.model == 'ResNet18':
        rotate_classifier = nn.Sequential(nn.Linear(512, 4))
    else:
        raise ValueError('unsupported model: %s' % params.model)

    rotate_classifier.cuda()

    if 'rotate' in tmp:
        print("loading rotate model")
        rotate_classifier.load_state_dict(tmp['rotate'])

    optimizer = torch.optim.Adam([{
        'params': model.parameters()
    }, {
        'params': rotate_classifier.parameters()
    }])
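    # Note: one Adam step per batch applies gradients accumulated from *two*
    # backward passes below (mixup loss, then rotation loss); zero_grad() is
    # called once, before the first backward.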

    print("stop_epoch", start_epoch, stop_epoch)

    for epoch in range(start_epoch, stop_epoch):
        print('\nEpoch: %d' % epoch)

        model.train()
        train_loss = 0
        rotate_loss = 0
        correct = 0
        total = 0
        torch.cuda.empty_cache()
        print("inside base_loader: ", len(base_loader))
        for batch_idx, (inputs, targets) in enumerate(base_loader):
            if use_gpu:
                inputs, targets = inputs.cuda(), targets.cuda()
            #print("shape of input: ", inputs.shape)
            lam = np.random.beta(params.alpha, params.alpha)
            f, outputs, target_a, target_b = model(inputs,
                                                   targets,
                                                   mixup_hidden=True,
                                                   mixup_alpha=params.alpha,
                                                   lam=lam)
            loss = mixup_criterion(criterion, outputs, target_a, target_b, lam)
            train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()  # gradients kept; the rotation loss below adds to them

            # Mixup accuracy: credit each prediction in proportion to lam.
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (
                lam * predicted.eq(target_a.data).cpu().sum().float() +
                (1 - lam) * predicted.eq(target_b.data).cpu().sum().float())

            # Self-supervised rotation task: for a quarter of the batch, add
            # 0/90/180/270-degree rotated copies with rotation labels 0-3.
            bs = inputs.size(0)
            inputs_ = []
            targets_ = []
            a_ = []
            indices = np.arange(bs)
            np.random.shuffle(indices)

            split_size = bs // 4
            for j in indices[:split_size]:
                x90 = inputs[j].transpose(2, 1).flip(1)   # rotate 90 degrees
                x180 = x90.transpose(2, 1).flip(1)        # rotate 180 degrees
                x270 = x180.transpose(2, 1).flip(1)       # rotate 270 degrees
                inputs_ += [inputs[j], x90, x180, x270]
                targets_ += [targets[j]] * 4
                a_ += [torch.tensor(k) for k in range(4)]

            inputs = torch.stack(inputs_, 0)
            targets = torch.stack(targets_, 0)
            a_ = torch.stack(a_, 0)

            if use_gpu:
                inputs = inputs.cuda()
                targets = targets.cuda()
                a_ = a_.cuda()

            rf, outputs = model(inputs)
            rotate_outputs = rotate_classifier(rf)
            rloss = criterion(rotate_outputs, a_)  # rotation-prediction loss
            closs = criterion(outputs, targets)    # plain classification loss
            loss = (rloss + closs) / 2.0

            rotate_loss += rloss.item()

            loss.backward()   # accumulates onto the mixup gradients
            optimizer.step()  # single update for both objectives

            if batch_idx % 50 == 0:
                print('{0}/{1}'.format(batch_idx, len(base_loader)),
                      'Loss: %.3f | Acc: %.3f%% | RotLoss: %.3f' %
                      (train_loss / (batch_idx + 1), 100. * correct / total,
                       rotate_loss / (batch_idx + 1)))

        if not os.path.isdir(params.checkpoint_dir):
            os.makedirs(params.checkpoint_dir)

        if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1):
            outfile = os.path.join(params.checkpoint_dir,
                                   '{:d}.tar'.format(epoch))
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

        model.eval()
        with torch.no_grad():
            test_loss = 0
            correct = 0
            total = 0
            for batch_idx, (inputs, targets) in enumerate(base_loader_test):
                if use_gpu:
                    inputs, targets = inputs.cuda(), targets.cuda()
                f, outputs = model(inputs)
                loss = criterion(outputs, targets)
                test_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += predicted.eq(targets.data).cpu().sum()

            print('Loss: %.3f | Acc: %.3f%%' %
                  (test_loss / (batch_idx + 1), 100. * correct / total))

        # Build the episodic validation head on top of the frozen features.
        backbone = params.model + '_dct' if params.dct_status else params.model
        valmodel = BaselineFinetune(model_dict[backbone],
                                    params.train_n_way,
                                    params.n_shot,
                                    loss_type='dist')
        valmodel.n_query = 15
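        # Episodic validation: each cached batch is reshaped into one
        # train_n_way x (n_shot + 15-query) episode; set_forward_adaptation
        # appears to return scores at three finetuning checkpoints, reported
        # below as 100/200/300 iterations.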
        acc_all1, acc_all2, acc_all3 = [], [], []
        for i, x in enumerate(loader):
            if params.dct_status:
                x = x.view(-1, channels, image_size_dct, image_size_dct)
            else:
                x = x.view(-1, channels, image_size, image_size)

            if use_gpu:
                x = x.cuda()

            with torch.no_grad():
                f, scores = model(x)
            f = f.view(params.train_n_way, params.n_shot + valmodel.n_query,
                       -1)
            scores = valmodel.set_forward_adaptation(f.cpu())
            acc = []
            for each_score in scores:
                pred = each_score.data.cpu().numpy().argmax(axis=1)
                # Episode labels: n_query consecutive queries per class.
                y = np.repeat(range(params.train_n_way), valmodel.n_query)
                acc.append(np.mean(pred == y) * 100)
            acc_all1.append(acc[0])
            acc_all2.append(acc[1])
            acc_all3.append(acc[2])

        print('Test Acc at 100= %4.2f%%' % (np.mean(acc_all1)))
        print('Test Acc at 200= %4.2f%%' % (np.mean(acc_all2)))
        print('Test Acc at 300= %4.2f%%' % (np.mean(acc_all3)))

        # Keep the checkpoint that is best at the 300-iteration mark.
        if np.mean(acc_all3) > val_acc_best:
            val_acc_best = np.mean(acc_all3)
            bestfile = os.path.join(params.checkpoint_dir, 'best.tar')
            torch.save(
                {
                    'epoch': epoch,
                    'state': model.state_dict(),
                    'rotate': rotate_classifier.state_dict()
                }, bestfile)

    return model
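A minimal sketch of how this function might be driven. The values are hypothetical; the real loaders, backbone construction, and params object come from the surrounding repo, and the model must return (features, logits) and accept the mixup keyword arguments used above.

from types import SimpleNamespace

params = SimpleNamespace(
    dct_status=False, channels=3, dataset='miniImagenet',
    checkpoint_dir='./checkpoints/miniImagenet',  # hypothetical path
    model='WideResNet28_10', alpha=2.0, save_freq=10,
    train_n_way=5, n_shot=1)
model = model_dict[params.model]().cuda()  # construction is repo-specific
model = train_s2m2(base_loader, base_loader_test, val_loader, model,
                   start_epoch=0, stop_epoch=400, params=params, tmp={})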
Example #2
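A plain cross-entropy baseline with the same checkpointing and episodic validation, but no mixup or rotation task; it reuses the imports and repo globals listed above.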
def train_baseline(base_loader, base_loader_test, val_loader, model,
                   start_epoch, stop_epoch, params, tmp):
    if params.dct_status:
        channels = params.channels
    else:
        channels = 3

    val_acc_best = 0.0

    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    val_file = os.path.join(params.checkpoint_dir,
                            'val_' + params.dataset + '.pt')
    if path.exists(val_file):
        loader = torch.load(val_file)
    else:
        loader = [x for x, _ in val_loader]
        torch.save(loader, val_file)

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.Adam(model.parameters())
    print("stop_epoch", start_epoch, stop_epoch)
    for epoch in range(start_epoch, stop_epoch):
        print('\nEpoch: %d' % epoch)
        model.train()
        train_loss = 0
        correct = 0
        total = 0

        for batch_idx, (inputs, targets) in enumerate(base_loader):
            if use_gpu:
                inputs, targets = inputs.cuda(), targets.cuda()
            f, outputs = model(inputs)
            loss = criterion(outputs, targets)
            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += predicted.eq(targets.data).cpu().sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 50 == 0:
                print(
                    '{0}/{1}'.format(batch_idx, len(base_loader)),
                    'Loss: %.3f | Acc: %.3f%%  ' %
                    (train_loss / (batch_idx + 1), 100. * correct / total))

        if not os.path.isdir(params.checkpoint_dir):
            os.makedirs(params.checkpoint_dir)

        if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1):
            outfile = os.path.join(params.checkpoint_dir,
                                   '{:d}.tar'.format(epoch))
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

        model.eval()
        with torch.no_grad():
            test_loss = 0
            correct = 0
            total = 0
            for batch_idx, (inputs, targets) in enumerate(base_loader_test):
                if use_gpu:
                    inputs, targets = inputs.cuda(), targets.cuda()
                f, outputs = model(inputs)
                loss = criterion(outputs, targets)
                test_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += predicted.eq(targets.data).cpu().sum()

            print('Loss: %.3f | Acc: %.3f%%' %
                  (test_loss / (batch_idx + 1), 100. * correct / total))
        torch.cuda.empty_cache()

        valmodel = BaselineFinetune(model_dict[params.model],
                                    params.train_n_way,
                                    params.n_shot,
                                    loss_type='dist')
        valmodel.n_query = 15
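        # Same episodic validation protocol as in Example #1.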
        acc_all1, acc_all2, acc_all3 = [], [], []
        for i, x in enumerate(loader):
            # print("len of loader: ",len(loader))
            # print("shape of x: ",x.shape)
            if params.dct_status:
                x = x.view(-1, channels, image_size_dct, image_size_dct)
            else:
                x = x.view(-1, channels, image_size, image_size)

            if use_gpu:
                x = x.cuda()

            with torch.no_grad():
                f, scores = model(x)
            f = f.view(params.train_n_way, params.n_shot + valmodel.n_query,
                       -1)
            scores = valmodel.set_forward_adaptation(f.cpu())
            acc = []
            for each_score in scores:
                pred = each_score.data.cpu().numpy().argmax(axis=1)
                y = np.repeat(range(params.train_n_way), valmodel.n_query)
                acc.append(np.mean(pred == y) * 100)
            acc_all1.append(acc[0])
            acc_all2.append(acc[1])
            acc_all3.append(acc[2])

        print('Test Acc at 100= %4.2f%%' % (np.mean(acc_all1)))
        print('Test Acc at 200= %4.2f%%' % (np.mean(acc_all2)))
        print('Test Acc at 300= %4.2f%%' % (np.mean(acc_all3)))

        if np.mean(acc_all3) > val_acc_best:
            val_acc_best = np.mean(acc_all3)
            bestfile = os.path.join(params.checkpoint_dir, 'best.tar')
            torch.save({'epoch': epoch, 'state': model.state_dict()}, bestfile)

    return model
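Invocation mirrors Example #1; a one-line sketch assuming the same hypothetical params object:

model = train_baseline(base_loader, base_loader_test, val_loader, model,
                       start_epoch=0, stop_epoch=400, params=params, tmp={})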