Example #1
def test_unknown_station():
    """
    Scenario:

    1. Request a temperature prediction for an unknown station

    * Acceptance criteria:
    - Server should reject the request
    """
    # Request a prediction
    err, _ = get_pred("unknown", 'temperature')
    assert err == 404, "It should have failed, 'unknown' does not exist"
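The helper `get_pred` used above is not shown on this page. A minimal sketch of what it could look like, assuming a plain HTTP API reachable through the `requests` library; the base URL, the route layout and the returned tuple shape are assumptions made for illustration, not part of the original test suite:

# Hypothetical helper assumed by the tests on this page: returns (error, payload).
# BASE_URL and the /stations/<name>/predictions/<metric> route are guesses.
import requests

BASE_URL = "http://localhost:8080"

def get_pred(station, metric):
    resp = requests.get("{0}/stations/{1}/predictions/{2}".format(BASE_URL, station, metric))
    if resp.status_code != 200:
        return resp.status_code, None
    return None, resp.json()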
Example #2
def test_basic_pred(nb_samples, timeframe):
    """
    Scenario:

    1. Create a station
    2. Put nb_samples samples shifted by one hour
    3. Request a temperature prediction

    * Acceptance criteria:
    - if temperatures for a minimum duration of a day have not been reported the
      server should reject the request
    - else it should return a valid temperature
    """
    # Compute the expected prediction date
    pred_date = timeframe + datetime.timedelta(hours=1)

    # Create a station
    name = random_name()
    ret = add_station(name)
    assert ret.status_code == 200

    # Put measures
    for idx in range(nb_samples):
        # Data
        temperature = random.randint(-50, 60)
        c_date = timeframe + datetime.timedelta(hours=-idx)  # idx hours before the timeframe
        str_time = c_date.strftime("%Y-%m-%dT%H:%M:%S+00:00")
        data = {"temperature": temperature, "date": "{0}".format(str_time)}
        # Put updated measures
        ret = put_measures(name, data)
        assert ret.status_code == 200

    # Request a prediction
    err, results = get_pred(name, 'temperature')

    if nb_samples >= 24:
        # Ensure prediction is in a valid range
        assert -50 <= int(results.get('value', 0)) <= 60

        # Ensure date prediction equal to last sample +1 hour
        print(results)
        assert results.get('date') == pred_date.strftime("%Y-%m-%dT%H:%M:%S+00:00")

    else:
        msg = "It should have failed `Not enough samples to predict`"
        assert err == 400, msg
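Because `test_basic_pred` takes `nb_samples` and `timeframe` as arguments, it is presumably driven by a pytest parametrization or fixtures defined elsewhere. A possible sketch, with illustrative values chosen around the 24-sample boundary the test branches on (the original suite may use different ones):

# Hypothetical parametrization for the test above; values are illustrative only.
import datetime
import pytest

@pytest.mark.parametrize("timeframe", [datetime.datetime(2020, 1, 2, 12, 0)])
@pytest.mark.parametrize("nb_samples", [1, 23, 24, 48])
def test_basic_pred(nb_samples, timeframe):
    ...  # body as shown above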
Example #3
if conf.device.type == 'cpu':
    learner.load_state(conf, 'cpu_final.pth', True, True)
else:
    learner.load_state(conf, 'final.pth', True, True)
learner.model.eval()
print('learner loaded')

diff1, similar1 = calc_diff_and_similar(all_inter_pairs, all_intra_pairs, conf,
                                        learner, findDistance)
diff2, similar2 = calc_diff_and_similar(all_inter_pairs, all_intra_pairs, conf,
                                        learner, findDistance_cos)

# import seaborn as sns

# sns.set(color_codes=True)
# sns.distplot(similar1, hist=False)
# sns.distplot(diff1, hist=False)

range1 = [1, 1.2, 1.4]
range2 = [0.4, 0.7, 1]

step = 0.1
import numpy as np
y_true = np.concatenate(
    (np.zeros(len(diff1), dtype=bool), np.ones(len(similar1), dtype=bool)),
    axis=0)

y_pred = get_pred(diff1, similar1, 1.2)

# print(y_pred)
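The snippet defines candidate thresholds (`range1` for the first distance, `range2` for the cosine distance) and the ground-truth labels, but the evaluation loop is cut off. A sketch of how the thresholds could be scored, assuming `get_pred(diff, similar, threshold)` returns a boolean prediction array in the same order as `y_true` (all "diff" pairs first, then the "similar" pairs):

# Sketch only: sweep the candidate thresholds and score each one against y_true.
for th in range1:
    acc = np.mean(np.asarray(get_pred(diff1, similar1, th)) == y_true)
    print('distance threshold {:.2f}: accuracy {:.4f}'.format(th, acc))

for th in range2:
    acc = np.mean(np.asarray(get_pred(diff2, similar2, th)) == y_true)
    print('cosine threshold {:.2f}: accuracy {:.4f}'.format(th, acc))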
Example #4
def main():
    # Settings
    parser = argparse.ArgumentParser(description='PyTorch CIFAR-100')
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        help='input batch size for training')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train')
    parser.add_argument('--lr', type=float, default=1e-6, help='learning rate')
    parser.add_argument('--dp', type=float, default=0.2, help='dropout rate')
    parser.add_argument(
        '--aug',
        type=str,
        default='strong',
        help='Type of data augmentation {none, standard, strong}')
    parser.add_argument('--noise_pattern',
                        type=str,
                        default='uniform',
                        help='Noise pattern')
    parser.add_argument('--noise_rate',
                        type=float,
                        default=0.2,
                        help='Noise rate')
    parser.add_argument('--val_size',
                        type=int,
                        default=5000,
                        help='size of (noisy) validation set')
    parser.add_argument('--save_model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    parser.add_argument('--teacher_path',
                        type=str,
                        default=None,
                        help='Path of the teacher model')
    parser.add_argument('--init_path',
                        type=str,
                        default=None,
                        help='DMI requires a pretrained model to initialize')
    parser.add_argument('--gpu_id',
                        type=int,
                        default=0,
                        help='index of gpu to use')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=200,
                        help='input batch size for testing')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='random seed (default: 0)')
    args = parser.parse_args()

    if args.teacher_path is None:
        exp_name = 'dmi_cifar100_{}{:.1f}_dp{:.1f}_aug{}_seed{}'.format(
            args.noise_pattern, args.noise_rate, args.dp, args.aug, args.seed)
    else:
        exp_name = 'dmi_cifar100_{}{:.1f}_dp{:.1f}_aug{}_student_seed{}'.format(
            args.noise_pattern, args.noise_rate, args.dp, args.aug, args.seed)
    logpath = '{}.txt'.format(exp_name)
    log(logpath, 'Settings: {}\n'.format(args))

    torch.manual_seed(args.seed)
    device = torch.device(
        'cuda:' + str(args.gpu_id) if torch.cuda.is_available() else 'cpu')

    # Datasets
    root = './data/CIFAR100'
    num_classes = 100
    kwargs = {
        'num_workers': 4,
        'pin_memory': True
    } if torch.cuda.is_available() else {}
    if args.aug == 'standard':
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])
    elif args.aug == 'strong':
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4, fill=128),
            transforms.RandomHorizontalFlip(),
            CIFAR10Policy(),
            transforms.ToTensor(),
            Cutout(
                n_holes=1, length=16
            ),  # (https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py)
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))
    ])

    dataset = datasets.CIFAR100(root, train=True, download=True)
    data, label = dataset.data, dataset.targets
    label_noisy = list(
        pd.read_csv(
            os.path.join('./data/CIFAR100/label_noisy',
                         args.noise_pattern + str(args.noise_rate) +
                         '.csv'))['label_noisy'].values.astype(int))
    train_dataset = DATASET_CUSTOM(root,
                                   data[:-args.val_size],
                                   label_noisy[:-args.val_size],
                                   transform=train_transform)
    val_dataset = DATASET_CUSTOM(root,
                                 data[-args.val_size:],
                                 label_noisy[-args.val_size:],
                                 transform=test_transform)
    test_dataset = datasets.CIFAR100(root,
                                     train=False,
                                     transform=test_transform)

    if args.teacher_path is not None:
        teacher_model = Wide_ResNet(args.dp,
                                    num_classes=num_classes,
                                    use_log_softmax=False).to(device)
        teacher_model.load_state_dict(torch.load(args.teacher_path))
        distill_dataset = DATASET_CUSTOM(root,
                                         data[:-args.val_size],
                                         label_noisy[:-args.val_size],
                                         transform=test_transform)
        pred = get_pred(teacher_model, device, distill_dataset,
                        args.test_batch_size)
        log(
            logpath, 'distilled noise rate: {:.2f}\n'.format(
                1 -
                (np.array(label[:-args.val_size]) == pred).sum() / len(pred)))
        train_dataset.targets = pred
        del teacher_model

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch_size,
                                             shuffle=False,
                                             **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    # Building model
    def DMI_loss(output, target):
        outputs = F.softmax(output, dim=1)
        targets = target.reshape(target.size(0), 1)
        y_onehot = torch.FloatTensor(target.size(0), num_classes)
        y_onehot.zero_()
        targets = targets.cpu()
        y_onehot.scatter_(1, targets, 1)
        y_onehot = y_onehot.transpose(0, 1).to(device)
        mat = y_onehot @ outputs  # (num_classes x num_classes) empirical joint of labels and predictions
        return -1.0 * torch.log(torch.abs(torch.det(mat.float())) + 0.001)

    model = Wide_ResNet(args.dp,
                        num_classes=num_classes,
                        use_log_softmax=False).to(device)
    model.load_state_dict(torch.load(args.init_path))

    # Training
    val_best, epoch_best, test_at_best = 0, 0, 0
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=5e-4)
        _, train_acc = train(args,
                             model,
                             device,
                             train_loader,
                             optimizer,
                             epoch,
                             criterion=DMI_loss)
        _, val_acc = test(args,
                          model,
                          device,
                          val_loader,
                          criterion=F.cross_entropy)
        _, test_acc = test(args,
                           model,
                           device,
                           test_loader,
                           criterion=F.cross_entropy)
        if val_acc > val_best:
            val_best, test_at_best, epoch_best = val_acc, test_acc, epoch
            if args.save_model:
                torch.save(model.state_dict(), '{}_best.pth'.format(exp_name))

        log(
            logpath,
            'Epoch: {}/{}, Time: {:.1f}s. '.format(epoch, args.epochs,
                                                   time.time() - t0))
        log(
            logpath,
            'Train: {:.2f}%, Val: {:.2f}%, Test: {:.2f}%; Val_best: {:.2f}%, Test_at_best: {:.2f}%, Epoch_best: {}\n'
            .format(100 * train_acc, 100 * val_acc, 100 * test_acc,
                    100 * val_best, 100 * test_at_best, epoch_best))

    # Saving
    if args.save_model:
        torch.save(model.state_dict(), '{}_last.pth'.format(exp_name))
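`DMI_loss` defined above is the determinant-based mutual information loss: it forms the `num_classes x num_classes` matrix `U = Y_onehot^T @ softmax(output)` from a batch and minimizes `-log |det U|`. A tiny standalone sketch of the same computation on random tensors (toy sizes, independent of the training script; `F.one_hot` is used here instead of the `scatter_` construction above):

# Standalone sketch of the DMI loss computation on random data.
import torch
import torch.nn.functional as F

batch, num_classes = 8, 5
logits = torch.randn(batch, num_classes)
target = torch.randint(0, num_classes, (batch,))

outputs = F.softmax(logits, dim=1)                 # (batch, C) predicted probabilities
y_onehot = F.one_hot(target, num_classes).float()  # (batch, C) one-hot labels
mat = y_onehot.t() @ outputs                       # (C, C) empirical joint matrix
loss = -torch.log(torch.abs(torch.det(mat)) + 0.001)
print(loss.item())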
Example #5
def main():
    # Settings
    parser = argparse.ArgumentParser(description='PyTorch CIFAR-100')
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        help='input batch size for training')
    parser.add_argument('--epochs',
                        type=int,
                        default=200,
                        help='number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
    parser.add_argument('--dp', type=float, default=0.0, help='dropout rate')
    parser.add_argument(
        '--aug',
        type=str,
        default='standard',
        help='Type of data augmentation {none, standard, strong}')
    parser.add_argument('--noise_pattern',
                        type=str,
                        default='uniform',
                        help='Noise pattern')
    parser.add_argument('--noise_rate',
                        type=float,
                        default=0.2,
                        help='Noise rate')
    parser.add_argument('--e_warm',
                        type=int,
                        default=120,
                        help='warm-up epochs without discarding any samples')
    parser.add_argument('--val_size',
                        type=int,
                        default=5000,
                        help='size of (noisy) validation set')
    parser.add_argument('--save_model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    parser.add_argument('--teacher_path',
                        type=str,
                        default=None,
                        help='Path of the teacher model')
    parser.add_argument('--gpu_id',
                        type=int,
                        default=0,
                        help='index of gpu to use')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=200,
                        help='input batch size for testing')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='random seed (default: 0)')
    args = parser.parse_args()

    if args.teacher_path is None:
        exp_name = 'gce_cifar100_{}{:.1f}_dp{:.1f}_aug{}_seed{}'.format(
            args.noise_pattern, args.noise_rate, args.dp, args.aug, args.seed)
    else:
        exp_name = 'gce_cifar100_{}{:.1f}_dp{:.1f}_aug{}_student_seed{}'.format(
            args.noise_pattern, args.noise_rate, args.dp, args.aug, args.seed)
    logpath = '{}.txt'.format(exp_name)
    log(logpath, 'Settings: {}\n'.format(args))

    torch.manual_seed(args.seed)
    device = torch.device(
        'cuda:' + str(args.gpu_id) if torch.cuda.is_available() else 'cpu')

    # Datasets
    root = './data/CIFAR100'
    kwargs = {
        'num_workers': 4,
        'pin_memory': True
    } if torch.cuda.is_available() else {}
    if args.aug == 'standard':
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])
    elif args.aug == 'strong':
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4, fill=128),
            transforms.RandomHorizontalFlip(),
            CIFAR10Policy(),
            transforms.ToTensor(),
            Cutout(
                n_holes=1, length=16
            ),  # (https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py)
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))
    ])

    dataset = datasets.CIFAR100(root, train=True, download=True)
    data, label = dataset.data, dataset.targets
    label_noisy = list(
        pd.read_csv(
            os.path.join('./data/CIFAR100/label_noisy',
                         args.noise_pattern + str(args.noise_rate) +
                         '.csv'))['label_noisy'].values.astype(int))
    train_dataset = DATASET_CUSTOM(root,
                                   data[:-args.val_size],
                                   label_noisy[:-args.val_size],
                                   transform=train_transform)
    val_dataset = DATASET_CUSTOM(root,
                                 data[-args.val_size:],
                                 label_noisy[-args.val_size:],
                                 transform=test_transform)
    test_dataset = datasets.CIFAR100(root,
                                     train=False,
                                     transform=test_transform)

    if args.teacher_path is not None:
        teacher_model = Wide_ResNet(args.dp,
                                    num_classes=100,
                                    use_log_softmax=False).to(device)
        teacher_model.load_state_dict(torch.load(args.teacher_path))
        distill_dataset = DATASET_CUSTOM(root,
                                         data[:-args.val_size],
                                         label_noisy[:-args.val_size],
                                         transform=test_transform)
        pred = get_pred(teacher_model, device, distill_dataset,
                        args.test_batch_size)
        log(
            logpath, 'distilled noise rate: {:.2f}\n'.format(
                1 -
                (np.array(label[:-args.val_size]) == pred).sum() / len(pred)))
        train_dataset.targets = pred
        del teacher_model

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch_size,
                                             shuffle=False,
                                             **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    # Building model
    def learning_rate(lr_init, epoch):
        optim_factor = 0
        if (epoch > 160):
            optim_factor = 3
        elif (epoch > 120):
            optim_factor = 2
        elif (epoch > 60):
            optim_factor = 1
        return lr_init * math.pow(0.2, optim_factor)

    # Lq (generalized cross entropy) loss: (1 - p_y^q) / q, bounded and more noise-tolerant than plain CE
    def lq_loss(output, target, q=0.7):
        output = F.softmax(output, dim=1)
        output_i = torch.gather(output, 1, torch.unsqueeze(target, 1))
        loss = torch.mean((1 - (output_i**q)) / q)
        return loss

    # Truncated Lq loss: samples whose predicted p_y does not exceed k contribute the constant (1 - k^q) / q
    def lq_loss_truncated(output, target, q=0.7, k=0.5):
        output = F.softmax(output, dim=1)
        output_i = torch.gather(output, 1, torch.unsqueeze(target, 1))

        k_repeat = torch.from_numpy(np.repeat(k, target.size(0))).type(
            torch.FloatTensor).to(device)
        weight = torch.gt(output_i,
                          k_repeat).type(torch.FloatTensor).to(device)

        loss = ((1 -
                 (output_i**q)) / q) * weight + ((1 -
                                                  (k**q)) / q) * (1 - weight)
        loss = torch.mean(loss)

        return loss

    model = Wide_ResNet(args.dp, num_classes=100,
                        use_log_softmax=False).to(device)

    # Training
    val_best, epoch_best, test_at_best = 0, 0, 0
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        optimizer = optim.SGD(model.parameters(),
                              lr=learning_rate(args.lr, epoch),
                              momentum=0.9,
                              weight_decay=5e-4)
        if epoch > args.e_warm and epoch % 10 == 0:  # after the first learning rate change
            criterion = lq_loss_truncated
        else:
            criterion = lq_loss
        _, train_acc = train(args,
                             model,
                             device,
                             train_loader,
                             optimizer,
                             epoch,
                             criterion=criterion)
        _, val_acc = test(args,
                          model,
                          device,
                          val_loader,
                          criterion=F.cross_entropy)
        _, test_acc = test(args,
                           model,
                           device,
                           test_loader,
                           criterion=F.cross_entropy)
        if val_acc > val_best:
            val_best, test_at_best, epoch_best = val_acc, test_acc, epoch
            if args.save_model:
                torch.save(model.state_dict(), '{}_best.pth'.format(exp_name))

        log(
            logpath,
            'Epoch: {}/{}, Time: {:.1f}s. '.format(epoch, args.epochs,
                                                   time.time() - t0))
        log(
            logpath,
            'Train: {:.2f}%, Val: {:.2f}%, Test: {:.2f}%; Val_best: {:.2f}%, Test_at_best: {:.2f}%, Epoch_best: {}\n'
            .format(100 * train_acc, 100 * val_acc, 100 * test_acc,
                    100 * val_best, 100 * test_at_best, epoch_best))

    # Saving
    if args.save_model:
        torch.save(model.state_dict(), '{}_last.pth'.format(exp_name))
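`lq_loss` defined above is the generalized cross entropy (Lq) loss `(1 - p_y^q) / q`: as `q` approaches 0 it recovers the usual cross entropy `-log p_y`, while larger `q` makes the loss bounded and less sensitive to mislabeled samples. A quick numerical sanity check of that limit (a sketch, not part of the training script):

# Sketch: the Lq loss approaches cross entropy as q goes to 0.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 100)
target = torch.randint(0, 100, (4,))

p_y = F.softmax(logits, dim=1).gather(1, target.unsqueeze(1))
for q in (0.7, 0.1, 0.01):
    print('q={:.2f}: Lq={:.4f}'.format(q, torch.mean((1 - p_y ** q) / q).item()))
print('CE     : {:.4f}'.format(F.cross_entropy(logits, target).item()))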
Example #6
def main():
    # Settings
    parser = argparse.ArgumentParser(description='PyTorch Clothing1M')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        help='number of epochs to train')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-3,
                        help='init learning rate')
    parser.add_argument('--save_model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    parser.add_argument(
        '--use_noisy_val',
        action='store_true',
        default=False,
        help=
        'Using the noisy validation setting. By default, using the benchmark setting.'
    )
    parser.add_argument('--init_path',
                        type=str,
                        default=None,
                        help='Path of a pretrained model)')
    parser.add_argument('--teacher_path',
                        type=str,
                        default=None,
                        help='Path of the teacher model')
    parser.add_argument('--soft_targets',
                        type=bool,
                        default=True,
                        help='Use soft targets')
    parser.add_argument('--n_gpu',
                        type=int,
                        default=2,
                        help='number of gpu to use')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=256,
                        help='input batch size for testing')
    parser.add_argument('--root',
                        type=str,
                        default='data/Clothing1M/',
                        help='root of dataset')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    args = parser.parse_args()

    if args.teacher_path is None:
        exp_name = 'clothing1m_batch{}_seed{}'.format(args.batch_size,
                                                      args.seed)
    else:
        teacher_name = args.teacher_path.replace('models/', '')
        teacher_name = teacher_name[:teacher_name.find('_')]
        if 'net1' in args.teacher_path:
            teacher_name = teacher_name + 'net1'
        elif 'net2' in args.teacher_path:
            teacher_name = teacher_name + 'net2'
        if args.soft_targets:
            exp_name = 'softstudent_of_{}_clothing1m_batch{}_seed{}'.format(
                teacher_name, args.batch_size, args.seed)
        else:
            exp_name = 'student_of_{}_clothing1m_batch{}_seed{}'.format(
                teacher_name, args.batch_size, args.seed)
        if args.init_path is None:
            args.init_path = args.teacher_path

    if args.use_noisy_val:
        exp_name = 'nv_' + exp_name
    logpath = '{}.txt'.format(exp_name)
    log(logpath, 'Settings: {}\n'.format(args))

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # soft loss
    def soft_cross_entropy(output, target):
        output = F.log_softmax(output, dim=1)
        loss = -torch.mean(torch.sum(output * target, dim=1))
        return loss

    # Datasets
    root = args.root
    num_classes = 14
    kwargs = {
        'num_workers': 32,
        'pin_memory': True
    } if torch.cuda.is_available() else {}
    train_transform = transforms.Compose([
        transforms.Resize((256)),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.6959, 0.6537, 0.6371),
                             (0.3113, 0.3192, 0.3214)),
    ])
    test_transform = transforms.Compose([
        transforms.Resize((256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.6959, 0.6537, 0.6371),
                             (0.3113, 0.3192, 0.3214)),
    ])

    train_dataset = Clothing1M(root,
                               mode='train',
                               transform=train_transform,
                               use_noisy_val=args.use_noisy_val)
    val_dataset = Clothing1M(root,
                             mode='val',
                             transform=test_transform,
                             use_noisy_val=args.use_noisy_val)
    test_dataset = Clothing1M(root,
                              mode='test',
                              transform=test_transform,
                              use_noisy_val=args.use_noisy_val)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch_size,
                                             shuffle=False,
                                             **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    if args.teacher_path is not None:
        teacher_model = resnet50(num_classes=num_classes).to(device)
        teacher_model = torch.nn.DataParallel(teacher_model,
                                              device_ids=list(range(
                                                  args.n_gpu)))
        state_dict = torch.load(args.teacher_path)
        if not list(state_dict.keys())[0][:7] == 'module.':
            state_dict = dict(('module.' + key, value)
                              for (key, value) in state_dict.items())
        teacher_model.load_state_dict(state_dict)
        distill_dataset = Clothing1M(root,
                                     mode='train',
                                     transform=test_transform,
                                     use_noisy_val=args.use_noisy_val)
        if args.soft_targets:
            pred = get_pred(teacher_model,
                            device,
                            distill_dataset,
                            args.test_batch_size,
                            num_workers=32,
                            output_softmax=True)
            train_criterion = soft_cross_entropy
        else:
            pred = get_pred(teacher_model,
                            device,
                            distill_dataset,
                            args.test_batch_size,
                            num_workers=32)
            train_criterion = F.cross_entropy
        train_dataset.targets = pred
        log(logpath, 'Get label from teacher {}.\n'.format(args.teacher_path))
        del teacher_model
    else:
        train_criterion = F.cross_entropy

    # Building model
    def learning_rate(lr_init, epoch):
        optim_factor = 0
        if (epoch > 5):
            optim_factor = 1
        return lr_init * math.pow(0.1, optim_factor)

    model = resnet50(pretrained=True)
    model.fc = nn.Linear(2048, num_classes)
    model = torch.nn.DataParallel(model.to(device),
                                  device_ids=list(range(args.n_gpu)))
    if args.init_path is not None:
        state_dict = torch.load(args.init_path)
        if not list(state_dict.keys())[0][:7] == 'module.':
            state_dict = dict(('module.' + key, value)
                              for (key, value) in state_dict.items())
        model.load_state_dict(state_dict)
        _, test_acc = test(args,
                           model,
                           device,
                           test_loader,
                           criterion=F.cross_entropy)
        log(logpath,
            'Initialized testing accuracy: {:.2f}\n'.format(100 * test_acc))
    cudnn.benchmark = True  # Accelerate training by enabling the inbuilt cudnn auto-tuner to find the best algorithm to use for your hardware.
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=1e-3)

    # Training
    save_every_epoch = True
    if save_every_epoch:
        vals = []
        directory = 'models/' + exp_name
        if not os.path.exists(directory):
            os.makedirs(directory)

    val_best, epoch_best, test_at_best = 0, 0, 0
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        lr = learning_rate(args.lr, epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        _, train_acc = train(args,
                             model,
                             device,
                             train_loader,
                             optimizer,
                             epoch,
                             criterion=train_criterion)
        _, val_acc = test(args,
                          model,
                          device,
                          val_loader,
                          criterion=F.cross_entropy)
        _, test_acc = test(args,
                           model,
                           device,
                           test_loader,
                           criterion=F.cross_entropy)
        if val_acc > val_best:
            val_best, test_at_best, epoch_best = val_acc, test_acc, epoch
            if args.save_model:
                torch.save(model.state_dict(), '{}_best.pth'.format(exp_name))
        if save_every_epoch:
            vals.append(val_acc)
            torch.save(model.state_dict(),
                       '{}/epoch{}.pth'.format(directory, epoch))

        log(
            logpath,
            'Epoch: {}/{}, Time: {:.1f}s. '.format(epoch, args.epochs,
                                                   time.time() - t0))
        log(
            logpath,
            'Train: {:.2f}%, Val: {:.2f}%, Test: {:.2f}%; Val_best: {:.2f}%, Test_at_best: {:.2f}%, Epoch_best: {}\n'
            .format(100 * train_acc, 100 * val_acc, 100 * test_acc,
                    100 * val_best, 100 * test_at_best, epoch_best))

    if save_every_epoch:
        np.save('{}/val.npy'.format(directory), vals)
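`soft_cross_entropy` above takes a full probability vector as the target, which is what the teacher's softened predictions look like when `--soft_targets` is set. With a one-hot target it coincides with the usual cross entropy, which the following sketch verifies on random data:

# Sketch: soft_cross_entropy on a one-hot target matches F.cross_entropy.
import torch
import torch.nn.functional as F

def soft_cross_entropy(output, target):
    output = F.log_softmax(output, dim=1)
    return -torch.mean(torch.sum(output * target, dim=1))

logits = torch.randn(4, 14)
labels = torch.randint(0, 14, (4,))
one_hot = F.one_hot(labels, 14).float()

print(soft_cross_entropy(logits, one_hot).item())  # equals the line below
print(F.cross_entropy(logits, labels).item())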
Example #7
def main():
    # Settings
    parser = argparse.ArgumentParser(description='PyTorch CIFAR-100')
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        help='input batch size for training')
    parser.add_argument('--epochs',
                        type=int,
                        default=200,
                        help='number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
    parser.add_argument('--dp', type=float, default=0.2, help='dropout rate')
    parser.add_argument(
        '--aug',
        type=str,
        default='strong',
        help='type of data augmentation {none, standard, strong}')
    parser.add_argument('--noise_pattern',
                        type=str,
                        default='uniform',
                        help='Noise pattern')
    parser.add_argument('--noise_rate',
                        type=float,
                        default=0.2,
                        help='Noise rate')
    parser.add_argument('--tau',
                        type=float,
                        default=0.2,
                        help='maximum discard ratio of large-loss samples')
    parser.add_argument('--e_warm',
                        type=int,
                        default=0,
                        help='warm-up epochs without discarding any samples')
    parser.add_argument('--val_size',
                        type=int,
                        default=5000,
                        help='size of (noisy) validation set')
    parser.add_argument('--save_model',
                        action='store_true',
                        default=False,
                        help='for Saving the current Model')
    parser.add_argument('--teacher_path',
                        type=str,
                        default=None,
                        help='path of the teacher model')
    parser.add_argument('--gpu_id',
                        type=int,
                        default=0,
                        help='index of gpu to use')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=200,
                        help='input batch size for testing')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='random seed (default: 0)')
    args = parser.parse_args()

    if args.teacher_path is None:
        exp_name = 'ct_cifar100_{}{:.1f}_warm{}_dp{:.1f}_aug{}_seed{}'.format(
            args.noise_pattern, args.noise_rate, args.e_warm, args.dp,
            args.aug, args.seed)
    else:
        exp_name = 'ct_cifar100_{}{:.1f}_warm{}_dp{:.1f}_aug{}_student_seed{}'.format(
            args.noise_pattern, args.noise_rate, args.e_warm, args.dp,
            args.aug, args.seed)
    logpath = '{}.txt'.format(exp_name)
    log(logpath, 'Settings: {}\n'.format(args))

    torch.manual_seed(args.seed)
    device = torch.device(
        'cuda:' + str(args.gpu_id) if torch.cuda.is_available() else 'cpu')

    # Datasets
    root = './data/CIFAR100'
    kwargs = {
        'num_workers': 4,
        'pin_memory': True
    } if torch.cuda.is_available() else {}
    if args.aug == 'standard':
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])
    elif args.aug == 'strong':
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4, fill=128),
            transforms.RandomHorizontalFlip(),
            CIFAR10Policy(),
            transforms.ToTensor(),
            Cutout(
                n_holes=1, length=16
            ),  # (https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py)
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))
    ])

    dataset = datasets.CIFAR100(root, train=True, download=True)
    data, label = dataset.data, dataset.targets
    label_noisy = list(
        pd.read_csv(
            os.path.join('./data/CIFAR100/label_noisy',
                         args.noise_pattern + str(args.noise_rate) +
                         '.csv'))['label_noisy'].values.astype(int))
    train_dataset = DATASET_CUSTOM(root,
                                   data[:-args.val_size],
                                   label_noisy[:-args.val_size],
                                   transform=train_transform)
    val_dataset = DATASET_CUSTOM(root,
                                 data[-args.val_size:],
                                 label_noisy[-args.val_size:],
                                 transform=test_transform)
    test_dataset = datasets.CIFAR100(root,
                                     train=False,
                                     transform=test_transform)

    if args.teacher_path is not None:
        teacher_model = Wide_ResNet(args.dp, num_classes=100).to(device)
        teacher_model.load_state_dict(torch.load(args.teacher_path))
        distill_dataset = DATASET_CUSTOM(root,
                                         data[:-args.val_size],
                                         label_noisy[:-args.val_size],
                                         transform=test_transform)
        pred = get_pred(teacher_model, device, distill_dataset,
                        args.test_batch_size)
        log(
            logpath, 'distilled noise rate: {:.2f}\n'.format(
                1 -
                (np.array(label[:-args.val_size]) == pred).sum() / len(pred)))
        train_dataset.targets = pred
        del teacher_model

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch_size,
                                             shuffle=False,
                                             **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    # Building model
    def learning_rate(lr_init, epoch):
        optim_factor = 0
        if (epoch > 160):
            optim_factor = 3
        elif (epoch > 120):
            optim_factor = 2
        elif (epoch > 60):
            optim_factor = 1
        return lr_init * math.pow(0.2, optim_factor)

    def get_keep_ratio(e, tau=args.tau, e_warm=args.e_warm):
        return 1. - tau * min(max((e - e_warm) / 10, 0), 1.)

    model1 = Wide_ResNet(args.dp, num_classes=100).to(device)
    model2 = Wide_ResNet(args.dp, num_classes=100).to(device)

    # Training
    val_best, epoch_best, test_at_best, index = 0, 0, 0, 0  # index: which of the two models is currently best
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        optimizer1 = optim.SGD(model1.parameters(),
                               lr=learning_rate(args.lr, epoch),
                               momentum=0.9,
                               weight_decay=5e-4)
        optimizer2 = optim.SGD(model2.parameters(),
                               lr=learning_rate(args.lr, epoch),
                               momentum=0.9,
                               weight_decay=5e-4)
        _, train_acc1, _, train_acc2 = train_ct(args, model1, model2,
                                                optimizer1, optimizer2, device,
                                                train_loader,
                                                get_keep_ratio(epoch))
        _, val_acc1 = test(args, model1, device, val_loader)
        _, val_acc2 = test(args, model2, device, val_loader)
        _, test_acc1 = test(args, model1, device, test_loader)
        _, test_acc2 = test(args, model2, device, test_loader)
        if max(val_acc1, val_acc2) > val_best:
            index = np.argmax([val_acc1, val_acc2])
            val_best, test_at_best, epoch_best = max(
                val_acc1, val_acc2), [test_acc1, test_acc2][index], epoch
            if args.save_model:
                torch.save([model1.state_dict(),
                            model2.state_dict()][index],
                           '{}_best.pth'.format(exp_name))

        log(
            logpath,
            'Epoch: {}/{}, Time: {:.1f}s. '.format(epoch, args.epochs,
                                                   time.time() - t0))
        log(
            logpath,
            'Train1: {:.2f}%, Val1: {:.2f}%, Test1: {:.2f}%, Train2: {:.2f}%, Val2: {:.2f}%, Test2: {:.2f}%; Val_best: {:.2f}%, Test_at_best: {:.2f}%, Epoch_best: {}\n'
            .format(100 * train_acc1, 100 * val_acc1, 100 * test_acc1,
                    100 * train_acc2, 100 * val_acc2, 100 * test_acc2,
                    100 * val_best, 100 * test_at_best, epoch_best))


    # wrong order 100*train_acc1, 100*train_acc2, 100*val_acc1, 100*test_acc1, 100*val_acc2, 100*test_acc2, 100*val_best, 100*test_at_best, epoch_best

    # Saving
    if args.save_model:
        torch.save([model1.state_dict(),
                    model2.state_dict()][index],
                   '{}_last.pth'.format(exp_name))
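`get_keep_ratio` above controls how aggressively co-teaching discards large-loss samples: everything is kept for the first `e_warm` epochs, then the discard ratio ramps linearly up to `tau` over the next 10 epochs and stays there. A short sketch printing the schedule with the parser defaults used here (`tau=0.2`, `e_warm=0`):

# Sketch: the keep-ratio schedule passed to train_ct above, using the defaults
# tau=0.2 and e_warm=0 from the argument parser.
def get_keep_ratio(e, tau=0.2, e_warm=0):
    return 1. - tau * min(max((e - e_warm) / 10, 0), 1.)

for epoch in (1, 5, 10, 20, 200):
    print('epoch {:3d}: keep {:.0%} of each batch'.format(epoch, get_keep_ratio(epoch)))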