Example 1
def enforce_single_threshold(categorical_results):  # Shamus O'Connor
    import numpy as np
    import utils
    single_threshold_data = {}
    thresholds = {}

    full_list = []
    max_index = 0  # Index that returns max accuracy
    acc = []

    univ_thresh = 0
    thresh_num = 500  # Num of thresholds to try
    thresh = np.linspace(0, 1, thresh_num)

    for group in categorical_results.keys():
        full_list += categorical_results[group]

    for t in thresh:
        proc_data = utils.apply_threshold(full_list, t)  # pooled data, thresholded at t
        acc.append(utils.get_num_correct(proc_data) /
                   len(proc_data))  # ACCURACY
#         acc.append(utils.apply_financials(proc_data))                 # COST

    max_index = acc.index(max(acc))
    univ_thresh = thresh[max_index]

    for race in categorical_results:
        subset = categorical_results[race]
        thresholds[race] = univ_thresh
        single_threshold_data[race] = utils.apply_threshold(
            subset, thresholds[race])

    return single_threshold_data, thresholds
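
The snippet assumes two helpers from a `utils` module that is not shown. A minimal sketch, assuming each item in the data lists is a (score, true_label) pair (signatures inferred from the call sites; the real module may differ):

def apply_threshold(pairs, threshold):
    # binarize the predicted scores at `threshold`, keeping the true labels
    return [(1 if score >= threshold else 0, label) for score, label in pairs]

def get_num_correct(pairs):
    # count items whose thresholded prediction matches the true label
    return sum(1 for pred, label in pairs if pred == label)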
Example 2
def run_epoch(data_loader, model, opt, device, is_train=True, desc=None):
    total_loss = 0
    total_correct = 0
    start = time.time()
    for i, batch in tqdm(enumerate(data_loader),
                         total=len(data_loader),
                         desc=desc):
        images = torch.unsqueeze(batch[0], 1).to(device)
        labels = batch[1].to(device)
        with torch.set_grad_enabled(is_train):
            preds = model(images)
            loss = F.cross_entropy(preds, labels)
            if is_train:
                opt.zero_grad()
                loss.backward()
                opt.step()
            total_loss += loss.item()  # .item() so we don't keep each batch's autograd graph alive
            total_correct += get_num_correct(preds, labels)
    elapsed = time.time() - start
    ms = 'average train loss' if is_train else 'average valid loss'
    print(ms + ': {}; average correct: {}'.format(
        total_loss / len(data_loader.dataset), total_correct /
        len(data_loader.dataset)))
    print(preds.argmax(dim=1))  # predicted classes for the last batch only
    return total_loss, total_correct / len(data_loader.dataset)
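
run_epoch relies on a get_num_correct helper that is not shown; the usual definition for a classifier (assumed here), plus a hypothetical driver loop with placeholder names:

def get_num_correct(preds, labels):
    # number of samples whose argmax prediction equals the label
    return preds.argmax(dim=1).eq(labels).sum().item()

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# for epoch in range(10):
#     run_epoch(train_loader, model, opt, device, is_train=True, desc='train')
#     run_epoch(valid_loader, model, opt, device, is_train=False, desc='valid')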
Example 3
def train(epoch, run, mod_name=''):
    total_train_loss = 0
    total_train_correct = 0
    incorrect_classifications_train = []
    epoch_classifications_train = []
    run.model.train()
    for batch_number, (images, labels, paths) in enumerate(run.train_loader):

        # for i, (image, label, path) in enumerate(zip(images, labels, paths)):
        #     save_plot_clip_frames(image, label, path, added_info_to_path = epoch)

        if run.grayscale:
            images = torch.unsqueeze(
                images, 1).double()  # add a channel dimension (grayscale)
        else:
            images = images.float().permute(0, 4, 1, 2, 3)  # channels-last to channels-first
        labels = labels.long()

        if torch.cuda.is_available():
            images, labels = images.cuda(), labels.cuda()

        # Gradients accumulate by default in PyTorch, so reset them each batch.
        run.optimizer.zero_grad()
        preds = run.model(images)  # Pass Batch

        loss = run.criterion(preds, labels)  # Calculate Loss
        total_train_loss += loss.item()
        # Calculate gradients: the direction to move toward the loss minimum
        # (the learning rate decides how far each step goes).
        loss.backward()
        # Update weights: the optimizer can do this because it was constructed
        # with the model's parameters.
        run.optimizer.step()

        num_correct = get_num_correct(preds, labels)
        total_train_correct += num_correct

        run.experiment.log_metric(mod_name + "Train batch accuracy",
                                  num_correct / len(labels) * 100,
                                  step=run.log_number_train)
        run.experiment.log_metric(mod_name + "Avg train batch loss",
                                  loss.item(),
                                  step=run.log_number_train)
        run.log_number_train += 1

        # print('Train: Batch number:', batch_number, 'Num correct:', num_correct, 'Accuracy:', "{:.2%}".format(num_correct/len(labels)), 'Loss:', loss.item())
        incorrect_classifications_train.append(
            get_mistakes(preds, labels, paths))
        for prediction in zip(preds, labels, paths):
            epoch_classifications_train.append(prediction)
    epoch_accuracy = calc_accuracy(epoch_classifications_train)

    run.experiment.log_metric(mod_name + "Train epoch accuracy",
                              epoch_accuracy,
                              step=epoch)
    run.experiment.log_metric(mod_name + "Avg train epoch loss",
                              total_train_loss / batch_number,
                              step=epoch)

    print('\nTrain: Epoch:', epoch, 'num correct:', total_train_correct,
          'Accuracy:',
          str(epoch_accuracy) + '%')
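
train() also calls get_mistakes and calc_accuracy, which are defined elsewhere in the project; plausible sketches inferred from how their results are used (assumptions, not the project's actual code):

def get_mistakes(preds, labels, paths):
    # collect (path, predicted_class, actual_class) for misclassified samples
    pred_classes = preds.argmax(dim=1)
    return [(paths[i], pred_classes[i].item(), labels[i].item())
            for i in range(len(labels)) if pred_classes[i] != labels[i]]

def calc_accuracy(classifications):
    # classifications: a list of (logits, label, path) tuples for the epoch
    correct = sum(1 for logits, label, _ in classifications
                  if logits.argmax().item() == label.item())
    return round(100.0 * correct / len(classifications), 2)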
Example 4
def enforce_maximum_profit(categorical_results):  # Shamus O'Connor
    import numpy as np
    import utils
    mp_data = {}
    thresholds = {}

    for race in categorical_results:
        subset = categorical_results[race]

        max_index = 0  # Index that returns max accuracy
        acc = []

        thresh_num = len(subset)  # Num of thresholds to try
        thresh = np.linspace(0, 1, thresh_num)

        for t in thresh:
            proc_data = utils.apply_threshold(subset, t)  # this group's data, thresholded at t
            acc.append(utils.get_num_correct(proc_data) /
                       len(proc_data))  # ACCURACY
#             acc.append(utils.apply_financials(proc_data,True))                 # COST

        max_index = acc.index(max(acc))
        thresholds[race] = thresh[max_index]

        mp_data[race] = utils.apply_threshold(subset, thresholds[race])

    return mp_data, thresholds
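
The expected input, inferred from the code: a dict mapping each group to a list of (score, true_label) pairs, with utils helpers like the sketch under Example 1. A toy call:

categorical_results = {
    'group_a': [(0.9, 1), (0.2, 0), (0.6, 1), (0.4, 0)],
    'group_b': [(0.8, 1), (0.1, 0), (0.7, 0), (0.5, 1)],
}
mp_data, thresholds = enforce_maximum_profit(categorical_results)
print(thresholds)  # the per-group threshold that maximized accuracy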
Example 5
# (the snippet begins mid-call; from the class names used later this is
# presumably the FashionMNIST test set being built -- the root path below is
# an assumption)
test_set = torchvision.datasets.FashionMNIST(root='./data',
                                             train=False,
                                             download=True,
                                             transform=transforms.Compose(
                                                 [transforms.ToTensor()]))

torch.set_grad_enabled(False)
network = networkClass.Network()
network.load_state_dict(
    torch.load(".\\trained\\batch_size=10 lr=0.001 shuffle=True epochs=10.pt"))

batch_size = 100
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)

images, labels = next(iter(test_loader))

total_loss = 0
total_correct = 0

prediction_loader = torch.utils.data.DataLoader(test_set, batch_size=10000)
test_preds = util.get_all_preds(network, prediction_loader)
preds_correct = util.get_num_correct(test_preds, test_set.targets)

util.print_training_results('N/A', preds_correct, 'N/A', len(test_set))  # divide by dataset size, not batch count

cm = confusion_matrix(test_set.targets, test_preds.argmax(dim=1))

names = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
         'Shirt', 'Sneaker', 'Bag', 'Ankle boot')

plt.figure(figsize=(10, 10))
util.plot_confusion_matrix(cm, names)
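
util.get_all_preds is not shown; a common formulation (assumed here) just concatenates the network's outputs over every batch:

@torch.no_grad()
def get_all_preds(model, loader):
    all_preds = torch.tensor([])
    for images, labels in loader:
        preds = model(images)
        all_preds = torch.cat((all_preds, preds), dim=0)
    return all_preds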
Example 6
def test_cond_nets(model_fx, model_set, 
                    test_loader, test_sampler,
                    loss_function_fx=nn.CrossEntropyLoss(), loss_function_set=nn.L1Loss(), 
                    device='cpu'):
    model_fx.eval()
    model_set.eval()
    test_set_size = len(test_sampler)
    total_loss_fx = 0
    total_loss_set = 0
    total_loss = 0
    total_correct_fx = 0
    total_correct_set = 0
    total_correct = 0
    results = []
    
    with torch.no_grad():
        for batch_idx, data in enumerate(test_loader):
            mels, labels, settings, filenames, indeces = data

            mels = mels.to(device)
            labels = labels.to(device)
            settings = settings.to(device)

            # predictions and loss for FxNet 
            preds_fx = model_fx(mels)
            loss_fx = loss_function_fx(preds_fx, labels)

            total_loss_fx += loss_fx.item()
            correct_fx = utils.get_num_correct_labels(preds_fx, labels)
            total_correct_fx += correct_fx

            # predictions and loss for SettingsNet
            cond_set = preds_fx.argmax(dim=1) # calculate labels for conditioning of setnet 
            preds_set = model_set(mels, cond_set)  # pass batch and labels for conditioning
            loss_set = loss_function_set(preds_set, settings)  # calculate loss

            total_loss_set += loss_set.item()
            correct_set = utils.get_num_correct_settings(preds_set, settings)
            total_correct_set += correct_set

            # loss for both networks
            loss = loss_fx + loss_set
            total_loss += loss.item()

            correct = utils.get_num_correct(preds_fx, preds_set, labels, settings)
            total_correct += correct
            
            for idx, filename in enumerate(filenames):
                results.append(
                    (indeces[idx].item(), 
                    filename, 
                    preds_fx[idx].argmax().item(),
                    labels[idx].item(),
                    np.round(preds_set[idx].detach().cpu().numpy(), 3),  # .cpu() so this also works on CUDA
                    np.round(settings[idx].detach().cpu().numpy(), 3)))
    
    print('====> Test Loss: {:.4f}'
                '\t Avg Loss: {:.4f}'
                '\t Fx Loss: {:.4f}'
                '\t Set Loss: {:.4f}'
                '\n\t\tCorrect: {:.0f}/{:.0f}'
                '\tFx Correct: {:.0f}/{}'
                '\tSet Correct: {:.0f}/{}'
                '\n\t\tPercentage Correct: {:.2f}'
                '\tPercentage Fx Correct: {:.2f}'
                '\tPercentage Set Correct: {:.2f}'.format(
                    total_loss,
                    total_loss / test_set_size,
                    total_loss_fx,
                    total_loss_set,
                    total_correct,
                    test_set_size,
                    total_correct_fx,
                    test_set_size,
                    total_correct_set,
                    test_set_size,
                    100 * total_correct / test_set_size,
                    100 * total_correct_fx / test_set_size,
                    100 * total_correct_set / test_set_size))
    
    return total_loss, total_correct, results
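
utils.get_num_correct takes both heads' outputs here; a plausible sketch (the combination rule and the tolerance are assumptions inferred from the context):

def get_num_correct(preds_fx, preds_set, labels, settings, tol=0.1):
    # a sample counts as correct only if the effect class matches AND every
    # predicted setting lies within `tol` of the target setting
    fx_ok = preds_fx.argmax(dim=1).eq(labels)
    set_ok = (preds_set - settings).abs().le(tol).all(dim=1)
    return (fx_ok & set_ok).sum().item()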
Example 7
def train_cond_nets(model_fx, model_set, 
                    optimizer_fx, optimizer_set, 
                    train_loader, train_sampler, epoch,
                    loss_function_fx=nn.CrossEntropyLoss(), loss_function_set=nn.L1Loss(), 
                    device='cpu'):
    model_fx.train()
    model_set.train()
    train_set_size = len(train_sampler)
    total_loss_fx = 0
    total_loss_set = 0
    total_loss = 0
    total_correct_fx = 0
    total_correct_set = 0
    total_correct = 0
    results = []

    for batch_idx, data in enumerate(train_loader):
        mels, labels, settings, filenames, indeces = data
        
        mels = mels.to(device)
        labels = labels.to(device)
        settings = settings.to(device)
        
        # predictions, loss and gradient for FxNet 
        preds_fx = model_fx(mels)
        loss_fx = loss_function_fx(preds_fx, labels)

        optimizer_fx.zero_grad()
        loss_fx.backward()
        optimizer_fx.step()

        total_loss_fx += loss_fx.item()
        correct_fx = utils.get_num_correct_labels(preds_fx, labels)
        total_correct_fx += correct_fx
        
        # predictions, loss and gradient for SettingsNet
        cond_set = preds_fx.argmax(dim=1) # calculate labels for conditioning of setnet
        preds_set = model_set(mels, cond_set)  # pass batch and labels for conditioning
        loss_set = loss_function_set(preds_set, settings)  # calculate loss

        optimizer_set.zero_grad()
        loss_set.backward()
        optimizer_set.step()

        total_loss_set += loss_set.item()
        correct_set = utils.get_num_correct_settings(preds_set, settings)
        total_correct_set += correct_set

        # predictions and loss for both networks
        loss = loss_fx + loss_set
        total_loss += loss.item()

        correct = utils.get_num_correct(preds_fx, preds_set, labels, settings)
        total_correct += correct
        
        for idx, filename in enumerate(filenames):
            results.append(
                (indeces[idx].item(), 
                filename, 
                preds_fx[idx].argmax().item(),
                labels[idx].item(),
                np.round(preds_set[idx].detach().cpu().numpy(), 3),  # detach + .cpu() before numpy conversion
                np.round(settings[idx].detach().cpu().numpy(), 3)))
    
        if batch_idx > 0 and batch_idx % 50 == 0:
            print('Train Epoch: {}\t[{}/{} ({:.0f}%)]\tTotal Loss: {:.4f}\tAvg Loss: {:.4f}'.format(
                        epoch, # epoch
                        batch_idx * len(labels), 
                        train_set_size,
                        100. * batch_idx / len(train_loader), # % completion
                        total_loss,
                        total_loss / (batch_idx * len(labels))))

    print('====> Epoch: {}'
                        '\tTotal Loss: {:.4f}'
                        '\t Avg Loss: {:.4f}'
                        '\t Fx Loss: {:.4f}'
                        '\t Set Loss: {:.4f}'
                        '\n\t\tCorrect: {:.0f}/{}'
                        '\tFx Correct: {:.0f}/{}'
                        '\tSet Correct: {:.0f}/{}'
                        '\n\t\tPercentage Correct: {:.2f}'
                        '\tPercentage Fx Correct: {:.2f}'
                        '\tPercentage Set Correct: {:.2f}'.format(
                            epoch,
                            total_loss,
                            total_loss / train_set_size,
                            total_loss_fx,
                            total_loss_set,
                            total_correct,
                            train_set_size,
                            total_correct_fx,
                            train_set_size,
                            total_correct_set,
                            train_set_size,
                            100 * total_correct / train_set_size,
                            100 * total_correct_fx / train_set_size,
                            100 * total_correct_set / train_set_size))
    
    return total_loss, total_correct, results
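
The two per-head counters are likewise defined elsewhere; minimal sketches consistent with the call sites (the settings tolerance is an assumption):

def get_num_correct_labels(preds, labels):
    # standard classification count: argmax matches the target class
    return preds.argmax(dim=1).eq(labels).sum().item()

def get_num_correct_settings(preds, settings, tol=0.1):
    # a settings vector counts as correct when every component is within
    # `tol` of the target (tolerance value assumed)
    return (preds - settings).abs().le(tol).all(dim=1).sum().item()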
Example 8
def val_multi_net(model, val_loader, val_sampler,
                  loss_function_fx=nn.CrossEntropyLoss(),
                  loss_function_set=nn.L1Loss(), device='cpu'):
    model.eval()
    val_set_size = len(val_sampler)
    total_loss_fx = 0
    total_loss_set = 0
    total_loss = 0
    total_correct_fx = 0
    total_correct_set = 0
    total_correct = 0
    results = []

    with torch.no_grad():
        for batch_idx, data in enumerate(val_loader):
            mels, labels, settings, filenames, indeces = data
            
            mels = mels.to(device)
            labels = labels.to(device)
            settings = settings.to(device)
            
            preds_fx, preds_set = model(mels)
            loss_fx = loss_function_fx(preds_fx, labels)
            loss_set = loss_function_set(preds_set, settings)
            loss = loss_fx + loss_set

            total_loss_fx += loss_fx.item()
            total_loss_set += loss_set.item()
            total_loss += loss.item()
            correct_fx = utils.get_num_correct_labels(preds_fx, labels)
            correct_set = utils.get_num_correct_settings(preds_set, settings)
            correct = utils.get_num_correct(preds_fx, preds_set, labels, settings)
            total_correct_fx += correct_fx
            total_correct_set += correct_set
            total_correct += correct
            
            for idx, filename in enumerate(filenames):
                results.append(
                    (indeces[idx].item(), 
                    filename, 
                    preds_fx[idx].argmax().item(),
                    labels[idx].item(),
                    np.round(preds_set[idx].detach().cpu().numpy(), 3),
                    np.round(settings[idx].detach().cpu().numpy(), 3)))
    
    print('====> Val Loss: {:.4f}'
                '\t Avg Loss: {:.4f}'
                '\t Fx Loss: {:.4f}'
                '\t Set Loss: {:.4f}'
                '\n\t\tCorrect: {:.0f}/{:.0f}'
                '\tFx Correct: {:.0f}/{}'
                '\tSet Correct: {:.0f}/{}'
                '\n\t\tPercentage Correct: {:.2f}'
                '\tPercentage Fx Correct: {:.2f}'
                '\tPercentage Set Correct: {:.2f}'.format(
                    total_loss,
                    total_loss / val_set_size,
                    total_loss_fx,
                    total_loss_set,
                    total_correct,
                    val_set_size,
                    total_correct_fx,
                    val_set_size,
                    total_correct_set,
                    val_set_size,
                    100 * total_correct / val_set_size,
                    100 * total_correct_fx / val_set_size,
                    100 * total_correct_set / val_set_size))
    
    return total_loss, total_correct, results
Example 9
def train_multi_net(model, optimizer, train_loader, train_sampler, epoch, 
                    loss_function_fx=nn.CrossEntropyLoss(), loss_function_set=nn.L1Loss(), device='cpu'):
    model.train()
    train_set_size = len(train_sampler)
    total_loss_fx = 0
    total_loss_set = 0
    total_loss = 0
    total_correct_fx = 0
    total_correct_set = 0
    total_correct = 0
    results = []

    for batch_idx, data in enumerate(train_loader):
        mels, labels, settings, filenames, indeces = data
        
        mels = mels.to(device)
        labels = labels.to(device)
        settings = settings.to(device)
        
        preds_fx, preds_set = model(mels)

        loss_fx = loss_function_fx(preds_fx, labels)
        loss_set = loss_function_set(preds_set, settings)
        loss = loss_fx + loss_set

        optimizer.zero_grad() # zero gradients otherwise get accumulated
        loss.backward() # calculate gradient
        optimizer.step() # update weights

        total_loss_fx += loss_fx.item()
        total_loss_set += loss_set.item()
        total_loss += loss.item()
        correct_fx = utils.get_num_correct_labels(preds_fx, labels)
        correct_set = utils.get_num_correct_settings(preds_set, settings)
        correct = utils.get_num_correct(preds_fx, preds_set, labels, settings)
        total_correct_fx += correct_fx
        total_correct_set += correct_set
        total_correct += correct
        
        for idx, filename in enumerate(filenames):
            results.append(
                (indeces[idx].item(), 
                filename, 
                preds_fx[idx].argmax().item(),
                labels[idx].item(),
                np.round(preds_set[idx].detach().cpu().numpy(), 3),
                np.round(settings[idx].detach().cpu().numpy(), 3)))
        
        if batch_idx > 0 and batch_idx % 50 == 0:
            print('Train Epoch: {}\t[{}/{} ({:.0f}%)]\tTotal Loss: {:.4f}\tAvg Loss: {:.4f}'.format(
                        epoch, # epoch
                        batch_idx * len(labels), 
                        train_set_size,
                        100. * batch_idx / len(train_loader), # % completion
                        total_loss,
                        total_loss / (batch_idx * len(labels))))

    print('====> Epoch: {}'
                        '\tTotal Loss: {:.4f}'
                        '\t Avg Loss: {:.4f}'
                        '\t Fx Loss: {:.4f}'
                        '\t Set Loss: {:.4f}'
                        '\n\t\tCorrect: {:.0f}/{}'
                        '\tFx Correct: {:.0f}/{}'
                        '\tSet Correct: {:.0f}/{}'
                        '\n\t\tPercentage Correct: {:.2f}'
                        '\tPercentage Fx Correct: {:.2f}'
                        '\tPercentage Set Correct: {:.2f}'.format(
                            epoch,
                            total_loss,
                            total_loss / train_set_size,
                            total_loss_fx,
                            total_loss_set,
                            total_correct,
                            train_set_size,
                            total_correct_fx,
                            train_set_size,
                            total_correct_set,
                            train_set_size,
                            100 * total_correct / train_set_size,
                            100 * total_correct_fx / train_set_size,
                            100 * total_correct_set / train_set_size))
    
    return total_loss, total_correct, results
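
A hypothetical driver for the multi-head variant above (MultiNet, the loaders, the sampler objects, and num_epochs are all placeholders for things defined elsewhere):

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MultiNet().to(device)  # placeholder model class
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(1, num_epochs + 1):
    train_multi_net(model, optimizer, train_loader, train_sampler, epoch,
                    device=device)
    val_multi_net(model, val_loader, val_sampler, device=device)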
Example 10
def evaluate(epoch, run, mod_name=''):
    incorrect_classifications_val = []
    total_val_loss = 0
    total_val_correct = 0
    best_val_acc = 0
    epoch_classifications_val = []
    run.model.eval()
    with torch.no_grad():
        for batch_number, (images, labels, paths) in enumerate(run.val_loader):

            if run.grayscale:
                images = torch.unsqueeze(
                    images, 1).double()  # add a channel dimension (grayscale)
            else:
                images = images.float().permute(0, 4, 1, 2, 3)  # channels-last to channels-first
            labels = labels.long()

            if torch.cuda.is_available():
                images, labels = images.cuda(), labels.cuda()

            preds = run.model(images)  # Pass Batch
            loss = run.criterion(preds, labels)  # Calculate Loss
            total_val_loss += loss.item()

            num_correct = get_num_correct(preds, labels)
            total_val_correct += num_correct

            run.experiment.log_metric(mod_name + "Val batch accuracy",
                                      num_correct / len(labels) * 100,
                                      step=run.log_number_val)
            run.experiment.log_metric(mod_name + "Avg val batch loss",
                                      loss.item(),
                                      step=run.log_number_val)
            run.log_number_val += 1

            # print('Val: Batch number:', batch_number, 'Num correct:', num_correct, 'Accuracy:', "{:.2%}".format(num_correct / len(labels)), 'Loss:', loss.item())
            # print_mistakes(preds, labels, paths)

            incorrect_classifications_val.append(
                get_mistakes(preds, labels, paths))

            for prediction in zip(preds, labels, paths):
                epoch_classifications_val.append(prediction)

        epoch_accuracy = calc_accuracy(epoch_classifications_val)

        run.experiment.log_metric(mod_name + "Val epoch accuracy",
                                  epoch_accuracy,
                                  step=epoch)
        run.experiment.log_metric(mod_name + "Avg val epoch loss",
                                  total_val_loss / batch_number,
                                  step=epoch)
        print('Val Epoch:', epoch, 'num correct:', total_val_correct,
              'Accuracy:',
              str(epoch_accuracy) + '%')

    is_best = (epoch_accuracy > run.best_val_acc) or (
        (epoch_accuracy >= run.best_val_acc) and
        (total_val_loss / (batch_number + 1) < run.best_val_loss))
    if is_best:
        print("Best run so far! updating params...")
        run.best_val_acc = epoch_accuracy
        run.best_val_loss = total_val_loss / (batch_number + 1)
        run.best_model_preds = epoch_classifications_val
        run.best_model_mistakes = incorrect_classifications_val
    save_checkpoint(
        {
            'epoch': epoch + 1,
            'state_dict': run.model.state_dict(),
            'best_acc1': run.best_val_acc,
            'optimizer': run.optimizer.state_dict(),
        }, is_best)

    # Step lr_scheduler
    run.lr_scheduler.step()
Example 11
# assumes `net`, `load_datasets`, and `get_num_correct` are defined earlier in
# the script, and that `tb` is a torch.utils.tensorboard SummaryWriter
optimizer = optim.Adam(net.parameters(), lr=0.01)
train_loader = load_datasets(1)
# images, labels = next(iter(train_loader))
# grid = torchvision.utils.make_grid(images)
# tb.add_image("image", grid)
# tb.add_graph(net, grid)
torch.distributed.init_process_group(backend='nccl')
# net = torch.nn.parallel.DistributedDataParallel(net)
local_rank = torch.distributed.get_rank()

num_epochs = 6
for e in range(num_epochs):
    total_loss = 0
    total_correct = 0
    for batch in train_loader:
        images, labels = batch
        preds = net(images)
        loss = F.cross_entropy(preds, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_correct += get_num_correct(preds, labels)
    print("epoch", e, "total_loss", total_loss, "total_correct", total_correct)
    tb.add_scalar("loss", total_loss, e)
    tb.add_scalar("Number Correct", total_correct, e)
    # tb.add_scalar("Accuracy")
    tb.add_histogram("conv1.bias", net.conv1.bias, e)
    tb.add_histogram("conv1.weight", net.conv1.weight, e)
    tb.add_histogram("conv1.weight.grad", net.conv1.weight.grad, e)
Example 12
        # (the snippet begins mid-script, inside an epoch loop; `network`,
        # `total_loss`, `total_correct`, `batch_size`, `use_tensorboard`, etc.
        # come from the omitted surrounding context)
        for batch in train_loader:  # Get Batch
            images, labels = batch

            preds = network(images)  # Pass Batch
            loss = F.cross_entropy(preds, labels)  # Calculate Loss

            #print('loss:', loss.item())
            #print(get_num_correct(preds, labels))

            optimizer.zero_grad()
            loss.backward()  # Calculate Gradients
            optimizer.step()  # Update Weights

            total_loss += loss.item() * batch_size
            total_correct += util.get_num_correct(preds, labels)

        if use_tensorboard:
            util.add_scalars_to_tensorboard(tb, epoch, total_correct,
                                            total_loss, len(train_set))
            util.add_histograms_to_tensorboard(network, tb, epoch)

        util.print_training_results(epoch, total_correct, total_loss,
                                    len(train_set))

    torch.save(
        network.state_dict(), ".\\trained\\" + comment.strip() + " epochs=" +
        str(epoch_length) + ".pt")

if use_tensorboard:
    tb.close()
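
Both this snippet and Example 5 call helpers from a local util module; sketches inferred from the call sites (assumptions, not the actual module):

def print_training_results(epoch, total_correct, total_loss, dataset_size):
    print('epoch:', epoch, 'total correct:', total_correct,
          'loss:', total_loss, 'accuracy:', total_correct / dataset_size)

def add_scalars_to_tensorboard(tb, epoch, total_correct, total_loss, dataset_size):
    tb.add_scalar('Loss', total_loss, epoch)
    tb.add_scalar('Number Correct', total_correct, epoch)
    tb.add_scalar('Accuracy', total_correct / dataset_size, epoch)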
Example 13
def train_supernet(model, args, *, bn_process=False, all_iters=None):
    logging.info("start warmup training...")

    optimizer = args.optimizer
    loss_function = args.loss_function
    scheduler = args.scheduler
    train_loader = args.train_loader

    t1 = time.time()

    model.train()

    if bn_process:
        adjust_bn_momentum(model, all_iters)

    all_iters += 1
    d_st = time.time()

    # print(model)

    total_correct = 0

    for ii, (data, target) in enumerate(train_loader):
        target = target.type(torch.LongTensor)
        data, target = data.cuda(args.gpu), target.cuda(args.gpu)
        data_time = time.time() - d_st

        optimizer.zero_grad()

        # one batch: forward pass with the maximal architecture
        # (`max_arc_rep` is assumed to be defined at module level)
        output = model(data, max_arc_rep)
        loss = loss_function(output, target)

        loss.backward()

        # drop all-zero gradients so the optimizer leaves unused paths untouched
        for p in model.parameters():
            if p.grad is not None and p.grad.sum() == 0:
                p.grad = None

        total_correct += get_num_correct(output, target)

        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)

        if ii % 6 == 0:  # log a few times per epoch (50,000 / bs(4096) ≈ 12 batches)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            logging.info("warmup batch acc1: {:.6f} lr: {:.6f}".format(
                acc1.item(),
                scheduler.get_last_lr()[0]))

            writer.add_scalar(
                "WTrain/Loss", loss.item(),
                all_iters * len(train_loader) * args.batch_size + ii)
            writer.add_scalar(
                "WTrain/acc1", acc1.item(),
                all_iters * len(train_loader) * args.batch_size + ii)
            writer.add_scalar(
                "WTrain/acc5", acc5.item(),
                all_iters * len(train_loader) * args.batch_size + ii)

        optimizer.step()

    writer.add_scalar("Accuracy",
                      total_correct / (len(train_loader) * args.batch_size),
                      all_iters)

    writer.add_histogram("first_conv.weight", model.module.first_conv.weight,
                         all_iters)

    writer.add_histogram("layer1[0].weight",
                         model.module.layer1[0].body[0].weight, all_iters)

    scheduler.step()

    top1, top5 = accuracy(output, target, topk=(1, 5))

    printInfo = 'TRAIN EPOCH {}: lr = {:.6f},\tloss = {:.6f},\t'.format(all_iters, scheduler.get_last_lr()[0], loss.item()) + \
                'Top-1 acc = {:.5f}%,\t'.format(top1.item()) + \
                'Top-5 acc = {:.5f}%,\t'.format(top5.item()) + \
                'data_time = {:.5f},\ttrain_time = {:.5f}'.format(
                    data_time, (time.time() - t1))

    logging.info(printInfo)
    t1 = time.time()

    if all_iters % args.save_interval == 0:
        save_checkpoint({
            'state_dict': model.state_dict(),
        }, all_iters)

    return all_iters
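
accuracy() and writer are defined elsewhere (writer is presumably a module-level SummaryWriter); for accuracy, the standard PyTorch ImageNet-example helper is assumed:

def accuracy(output, target, topk=(1,)):
    # top-k accuracy over a batch, returned as percentages
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res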
Example 14
def train_subnet(model,
                 args,
                 *,
                 bn_process=False,
                 all_iters=None,
                 arch_loader=None):
    logging.info("start subnet training...")
    optimizer = args.optimizer
    loss_function = args.loss_function
    scheduler = args.scheduler
    train_loader = args.train_loader

    t1 = time.time()

    model.train()

    if bn_process:
        adjust_bn_momentum(model, all_iters)

    all_iters += 1
    d_st = time.time()

    total_correct = 0

    for data, target in train_loader:
        target = target.type(torch.LongTensor)
        data, target = data.cuda(args.gpu), target.cuda(args.gpu)
        data_time = time.time() - d_st
        optimizer.zero_grad()

        # fair_arc_list = arch_loader.generate_fair_batch()
        fair_arc_list = arch_loader.generate_niu_fair_batch()
        # fair_arc_list = arch_loader.get_random_batch(25)

        for ii, arc in enumerate(fair_arc_list):
            # all architectures in the fair batch
            output = model(data, arch_loader.convert_list_arc_str(arc))
            loss = loss_function(output, target)

            loss_reduce = reduce_tensor(loss, 0, args.world_size)  # all-reduced loss (assigned but not used below)

            loss.backward()

            for p in model.parameters():
                if p.grad is not None and p.grad.sum() == 0:
                    p.grad = None

            total_correct += get_num_correct(output, target)

            if ii % 15 == 0 and args.local_rank == 0:
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                logging.info(
                    "epoch: {:4d} \t acc1:{:.4f} \t acc5:{:.4f} \t loss:{:.4f}"
                    .format(all_iters, acc1.item(), acc5.item(), loss.item()))

                writer.add_scalar(
                    "Train/Loss", loss.item(),
                    all_iters * len(train_loader) * args.batch_size + ii)
                writer.add_scalar(
                    "Train/acc1", acc1.item(),
                    all_iters * len(train_loader) * args.batch_size + ii)
                writer.add_scalar(
                    "Train/acc5", acc5.item(),
                    all_iters * len(train_loader) * args.batch_size + ii)

    # the factor of 16 reflects the fair sampling strategy: 16 architectures
    # are trained per batch, so each sample is counted 16 times above
    if args.local_rank == 0:
        writer.add_scalar(
            "Accuracy",
            total_correct / (len(train_loader) * args.batch_size * 16),
            all_iters)
        writer.add_histogram("first_conv.weight",
                             model.module.first_conv.weight, all_iters)

        writer.add_histogram("layer1[0].weight",
                             model.module.layer1[0].body[0].weight, all_iters)

    torch.nn.utils.clip_grad_norm_(model.parameters(), 5)

    # note: a single optimizer step runs after the whole pass; since
    # zero_grad() runs per data batch, only the last batch's accumulated
    # gradients are applied here
    optimizer.step()
    scheduler.step()

    if all_iters % args.save_interval == 0 and args.local_rank == 0:
        save_checkpoint({
            'state_dict': model.state_dict(),
        }, all_iters)

    return all_iters
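
reduce_tensor is not shown; a hypothetical sketch that averages a tensor across DDP ranks (the role of the middle argument in the real helper is unknown and ignored here):

def reduce_tensor(tensor, rank, world_size):
    rt = tensor.clone()
    torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
    rt /= world_size
    return rt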