Example #1
def download(baseurl, parameters=None, headers=None, data=None, method=methods.GET):
    '''Download data from a URL and return it as a string
    @param baseurl Url to download from (e.g. http://www.google.com)
    @param parameters Parameter dict to be encoded with url
    @param headers Headers dict to pass with Request
    @param data Request body
    @param method Request method
    @returns String of data from URL
    '''
    # None defaults avoid sharing (and mutating) a single dict across calls.
    parameters = parameters or {}
    headers = dict(headers or {})
    data = data or {}
    method = methods.validate(method)
    url = '?'.join([baseurl, urlencode(parameters)])
    log.debug('Downloading: ' + url)
    content = ""
    last_err = None
    for _ in range(MAX_RETRIES):
        try:
            headers.update({USER_AGENT: USER_AGENT_STRING})
            response = requests.request(method=method,
                                        url=url,
                                        headers=headers,
                                        data=data,
                                        verify=SSL_VERIFICATION)
            content = response.content
            if not content:
                content = '{"status": %d}' % response.status_code
            break
        except Exception as err:
            if not isinstance(err, URLError):
                log.debug("Error %s during HTTP Request, aborting", repr(err))
                raise  # propagate non-URLError
            last_err = err
            log.debug("Error %s during HTTP Request, retrying", repr(err))
    else:
        raise last_err  # all retries failed
    return content
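A minimal usage sketch for download, assuming the module-level names it relies on (methods, MAX_RETRIES, SSL_VERIFICATION, USER_AGENT, USER_AGENT_STRING, log) are defined as implied above; the endpoint and query parameters are hypothetical.

# Hypothetical endpoint and parameters, purely for illustration.
body = download('http://example.com/api/items',
                parameters={'page': 1, 'limit': 20},
                headers={'Accept': 'application/json'})
print(body)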
Example #2
def get_json(baseurl, parameters=None, headers=None, data=None, method=methods.GET):
    '''Download data from a URL and return it as JSON
    @param baseurl Url to download from
    @param parameters Parameter dict to be encoded with url
    @param headers Headers dict to pass with Request
    @param data Request body
    @param method Request method
    @returns JSON object with data from URL
    '''
    method = methods.validate(method)
    jsonString = download(baseurl, parameters, headers, data, method)
    jsonDict = json.loads(jsonString)
    log.debug(json.dumps(jsonDict, indent=4, sort_keys=True))
    return jsonDict
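A matching sketch for get_json, under the same assumptions; the URL, query parameters, and the 'status' key are hypothetical.

# Hypothetical call; the parsed JSON is also pretty-printed to the debug log.
result = get_json('http://example.com/api/items', parameters={'id': 42})
print(result.get('status'))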
Example #3
def main():
    global args, best_prec1, best_test_prec1
    global acc1_stu1_tr, losses_stu1_tr, losses_stu1_cl_tr
    global acc1_stu2_tr, losses_stu2_tr, losses_stu2_cl_tr
    global acc1_t_tr, losses_t_tr
    global acc1_t_val, losses_t_val
    global acc1_t_test, losses_t_test
    global learning_rate, weights_cl
    args = parser.parse_args()

    # Store the training losses
    loss_dict = {
        # teacher (model): classification, consistency, total losses
        "loss_class": [],
        "loss_cl": [],
        "loss_total": [],
        # student 1: classification, consistency, total losses
        "loss_class_1": [],
        "loss_cl_1": [],
        "loss_total_1": [],
        # student 2: classification, consistency, total losses
        "loss_class_2": [],
        "loss_cl_2": [],
        "loss_total_2": [],
        # validation and test losses
        "val_loss": [],
        "test_loss": []
    }

    # Model architecture selection
    if args.arch == 'net13':
        print("Model: %s" % args.arch)
        student = net13(num_classes=args.num_classes, dropRatio=args.drop_rate)
        if args.model == 'ds_mt':
            student2 = net13(args.num_classes,
                             dropRatio=args.drop_rate,
                             isL2=True)
        if args.model == 'ms':
            student1 = net13(num_classes=args.num_classes,
                             dropRatio=args.drop_rate)
            student2 = net13(num_classes=args.num_classes,
                             dropRatio=args.drop_rate)
    elif args.arch == 'resnet18':
        print("Model: %s" % args.arch)
        student = resnet18(num_classes=args.num_classes)
        if args.model == 'ds_mt':
            student2 = resnet18(args.num_classes, isL2=True)
        if args.model == 'ms':
            student1 = resnet18(num_classes=args.num_classes)
            student2 = resnet18(num_classes=args.num_classes)
    else:
        assert (False)

    # Per-algorithm model initialization
    if args.model == 'mt':
        import copy
        teacher = copy.deepcopy(student)
        teacher_model = torch.nn.DataParallel(teacher).cuda()  # multi-GPU data parallelism

    if args.model == 'mt+':
        import copy
        student1 = copy.deepcopy(student)
        student2 = copy.deepcopy(student)
        student1_model = torch.nn.DataParallel(student1).cuda()  # multi-GPU data parallelism
        student2_model = torch.nn.DataParallel(student2).cuda()  # multi-GPU data parallelism

    if args.model in ('ds_mt', 'd-ds_mt'):
        import copy
        student1 = copy.deepcopy(student)
        student1_model = torch.nn.DataParallel(student1).cuda()  # multi-GPU data parallelism
        student2_model = torch.nn.DataParallel(student2).cuda()  # multi-GPU data parallelism

    if args.model == 'ms':
        student1_model = torch.nn.DataParallel(student1).cuda()  # multi-GPU data parallelism
        student2_model = torch.nn.DataParallel(student2).cuda()  # multi-GPU data parallelism

    # if args.model == 'pso_mt+':
    #     student1_model = torch.nn.DataParallel(student1).cuda()  # multi-GPU data parallelism
    #     student2_model = torch.nn.DataParallel(student2).cuda()  # multi-GPU data parallelism
    #     # keep the best parameters seen so far
    #     best_his_param = {
    #         'local_loss': 999,
    #         'his_loss': 999,
    #     }

    student_model = torch.nn.DataParallel(student).cuda()

    # Resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            loss_dict = checkpoint['loss_dict']
            best_prec1 = checkpoint['best_prec1']
            student_model.load_state_dict(checkpoint['student_state_dict'])

            if args.model == 'mt':
                teacher_model.load_state_dict(checkpoint['teacher_state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            elif args.model != 'baseline' and args.model != 'pi':
                student1_model.load_state_dict(
                    checkpoint['student1_state_dict'])
                student2_model.load_state_dict(
                    checkpoint['student2_state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Optimizer choice
    if args.optim not in ('sgd', 'adam'):
        print('Not Implemented Optimizer')
        assert (False)

    # Optimizer for backpropagation
    if args.optim == 'adam':
        print('Using Adam optimizer')
        optimizer = torch.optim.Adam(student_model.parameters(),
                                     args.lr,
                                     betas=(0.9, 0.999),
                                     weight_decay=args.weight_decay)
        if args.model not in ('mt', 'baseline', 'pi'):
            student1_optimizer = torch.optim.Adam(
                student1_model.parameters(),
                args.lr,
                betas=(0.9, 0.999),
                weight_decay=args.weight_decay)
            student2_optimizer = torch.optim.Adam(
                student2_model.parameters(),
                args.lr,
                betas=(0.9, 0.999),
                weight_decay=args.weight_decay)

    elif args.optim == 'sgd':
        print('Using SGD optimizer')
        optimizer = torch.optim.SGD(student_model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        if args.model not in ('mt', 'baseline', 'pi'):
            student1_optimizer = torch.optim.SGD(
                student1_model.parameters(),
                args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay)
            student2_optimizer = torch.optim.SGD(
                student2_model.parameters(),
                args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay)

    # Checkpoint directory setup
    ckpt_dir = args.ckpt + '/' + args.dataset + '_' + str(
        args.label_num) + '_' + args.arch + '_' + args.model + '_' + args.optim
    ckpt_dir = ckpt_dir + '_e%d' % (args.epochs)

    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    cudnn.benchmark = True

    # Data loading
    label_loader, unlabel_loader, val_loader, test_loader = \
        data_create.__dict__[args.dataset](label_num=args.label_num, boundary=args.boundary,
                                           batch_size=args.batch_size, num_workers=args.workers)

    # Loss functions
    criterion = nn.CrossEntropyLoss(reduction='sum').cuda()
    criterion_mse = nn.MSELoss(reduction='sum').cuda()
    criterion_kl = nn.KLDivLoss(reduction='sum').cuda()
    criterion_l1 = nn.L1Loss(reduction='sum').cuda()

    criterions = (criterion, criterion_mse, criterion_kl, criterion_l1)
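    # Note: reduction='sum' returns per-batch summed (not averaged) losses;
    # the train_* helpers are assumed to normalize them before logging.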

    # Training loop
    for epoch in range(args.start_epoch, args.epochs):
        # Adjust the learning rate
        if args.optim == 'adam':
            print('Learning rate schedule for Adam')
            lr = adjust_learning_rate_adam(optimizer, epoch)
            if args.model not in ('mt', 'baseline', 'pi'):
                _ = adjust_learning_rate_adam(student1_optimizer, epoch)
                _ = adjust_learning_rate_adam(student2_optimizer, epoch)

        elif args.optim == 'sgd':
            print('Learning rate schedule for SGD')
            lr = adjust_learning_rate(optimizer, epoch)
            if args.model not in ('mt', 'baseline', 'pi'):
                _ = adjust_learning_rate(student1_optimizer, epoch)
                _ = adjust_learning_rate(student2_optimizer, epoch)

        # train for one epoch
        if args.model == 'baseline':
            print('Supervised Training')
            # The baseline repeats 10 times since the training set is small.
            for i in range(10):
                anser_dict = train_sup(label_loader, student_model, criterions,
                                       optimizer, epoch, args)
                weight_cl = 0.0
                anser_dict['weight_cl'] = weight_cl
            # Record losses for plotting
            loss_dict['loss_total'].append(anser_dict['total_loss'])
        elif args.model == 'pi':
            print('Pi model')
            anser_dict = train_pi(label_loader, unlabel_loader, student_model,
                                  criterions, optimizer, epoch, args)
            loss_dict['loss_class'].append(anser_dict['class_loss'])
            loss_dict['loss_cl'].append(anser_dict['pi_loss'])
            loss_dict['loss_total'].append(anser_dict['total_loss'])

        elif args.model == 'mt':
            print('Mean Teacher model')
            anser_dict = train_mt(label_loader, unlabel_loader, teacher_model,
                                  student_model, criterions, optimizer, epoch,
                                  args)
            loss_dict['loss_class'].append(anser_dict['class_loss'])
            loss_dict['loss_cl'].append(anser_dict['cl_loss'])
            loss_dict['loss_total'].append(anser_dict['total_loss'])

        elif args.model == 'mt+':
            print('Mean Teacher Plus Student model')
            # (student-1 top1, student-1 class loss, student-1 consistency loss, teacher top1, consistency weight)
            anser_dict = train_mtp(label_loader, unlabel_loader, student_model,
                                   student1_model, student2_model, criterions,
                                   student1_optimizer, student2_optimizer,
                                   epoch, args)
            loss_dict['loss_class'].append(anser_dict['ema_class_loss'])
            loss_dict['loss_class_1'].append(anser_dict['1_class_loss'])
            loss_dict['loss_cl_1'].append(anser_dict['1_cl_loss'])
            loss_dict['loss_total_1'].append(anser_dict['1_total_loss'])
            loss_dict['loss_class_2'].append(anser_dict['2_class_loss'])
            loss_dict['loss_cl_2'].append(anser_dict['2_cl_loss'])
            loss_dict['loss_total_2'].append(anser_dict['2_total_loss'])

        elif args.model == 'ds_mt':
            print('Dual Student with Mean Teacher Model')
            # (student-1 top1, student-1 class loss, student-1 consistency loss, teacher top1, consistency weight)
            anser_dict = train_mtp(label_loader,
                                   unlabel_loader,
                                   student_model,
                                   student1_model,
                                   student2_model,
                                   criterions,
                                   student1_optimizer,
                                   student2_optimizer,
                                   epoch,
                                   args,
                                   c_flag=True)
            loss_dict['loss_class'].append(anser_dict['ema_class_loss'])
            loss_dict['loss_class_1'].append(anser_dict['1_class_loss'])
            loss_dict['loss_cl_1'].append(anser_dict['1_cl_loss'])
            loss_dict['loss_total_1'].append(anser_dict['1_total_loss'])
            loss_dict['loss_class_2'].append(anser_dict['2_class_loss'])
            loss_dict['loss_cl_2'].append(anser_dict['2_cl_loss'])
            loss_dict['loss_total_2'].append(anser_dict['2_total_loss'])

        elif args.model == 'd-ds_mt':
            print('Dual Student with Double Mean Teacher Model')
            # (student-1 top1, student-1 class loss, student-1 consistency loss, teacher top1, consistency weight)
            anser_dict = train_mtp(label_loader, unlabel_loader, student_model,
                                   student1_model, student2_model, criterions,
                                   student1_optimizer, student2_optimizer,
                                   epoch, args)
            loss_dict['loss_class'].append(anser_dict['ema_class_loss'])
            loss_dict['loss_class_1'].append(anser_dict['1_class_loss'])
            loss_dict['loss_cl_1'].append(anser_dict['1_cl_loss'])
            loss_dict['loss_total_1'].append(anser_dict['1_total_loss'])
            loss_dict['loss_class_2'].append(anser_dict['2_class_loss'])
            loss_dict['loss_cl_2'].append(anser_dict['2_cl_loss'])
            loss_dict['loss_total_2'].append(anser_dict['2_total_loss'])

        elif args.model == 'ms':
            print('Multiple Student Model')
            # (student-1 top1, student-1 class loss, student-1 consistency loss, teacher top1, consistency weight)
            anser_dict = train_mtp(label_loader, unlabel_loader, student_model,
                                   student1_model, student2_model, criterions,
                                   student1_optimizer, student2_optimizer,
                                   epoch, args)
            loss_dict['loss_class'].append(anser_dict['ema_class_loss'])
            loss_dict['loss_class_1'].append(anser_dict['1_class_loss'])
            loss_dict['loss_cl_1'].append(anser_dict['1_cl_loss'])
            loss_dict['loss_total_1'].append(anser_dict['1_total_loss'])
            loss_dict['loss_class_2'].append(anser_dict['2_class_loss'])
            loss_dict['loss_cl_2'].append(anser_dict['2_cl_loss'])
            loss_dict['loss_total_2'].append(anser_dict['2_total_loss'])
        # elif args.model == 'pso_mt+':
        #     print('pso_mt+ model')
        #     # (student-1 top1, student-1 class loss, student-1 consistency loss, teacher top1, consistency weight)
        #     prec1_1_tr, loss_1_tr, loss_1_cl_tr, prec1_t_tr, weight_cl  = train_psomt_pul(
        #         label_loader, unlabel_loader, student_model, student1_model, student2_model,
        #         criterions, student1_optimizer, student2_optimizer, epoch, args, best_his_param)
        else:
            print("Not Implemented ", args.model)
            assert (False)

        # Teacher validation and test
        if args.model == 'mt':
            prec1_t_val, loss_t_val = validate(val_loader, teacher_model,
                                               criterions, args, 'valid')
            prec1_t_test, loss_t_test = validate(test_loader, teacher_model,
                                                 criterions, args, 'test')
        else:
            prec1_t_val, loss_t_val = validate(val_loader, student_model,
                                               criterions, args, 'valid')
            prec1_t_test, loss_t_test = validate(test_loader, student_model,
                                                 criterions, args, 'test')

        loss_dict['val_loss'].append(loss_t_val)
        loss_dict['test_loss'].append(loss_t_test)

        # Record training results in the checkpoint
        if args.model in ('baseline', 'pi'):
            acc1_stu1_tr.append(anser_dict['top1'])
            losses_stu1_tr.append(anser_dict['total_loss'])
            acc1_t_val.append(prec1_t_val)
            losses_t_val.append(loss_t_val)
            acc1_t_test.append(prec1_t_test)
            losses_t_test.append(loss_t_test)
        elif args.model == 'mt':
            # student training
            acc1_stu1_tr.append(anser_dict['top1_1'])
            acc1_t_tr.append(anser_dict['top1_t'])
            losses_stu1_tr.append(anser_dict['total_loss'])
            losses_stu1_cl_tr.append(anser_dict['cl_loss'])
            # validation
            acc1_t_val.append(prec1_t_val)
            losses_t_val.append(loss_t_val)
            # test
            acc1_t_test.append(prec1_t_test)
            losses_t_test.append(loss_t_test)
        else:
            # student 1 training
            acc1_stu1_tr.append(anser_dict['top1_1'])
            acc1_stu2_tr.append(anser_dict['top1_2'])
            acc1_t_tr.append(anser_dict['top1_t'])
            losses_stu1_tr.append(anser_dict['1_total_loss'])
            losses_stu2_tr.append(anser_dict['2_total_loss'])
            losses_t_tr.append(anser_dict['ema_class_loss'])
            losses_stu1_cl_tr.append(anser_dict['1_cl_loss'])
            acc1_t_val.append(prec1_t_val)
            losses_t_val.append(loss_t_val)
            acc1_t_test.append(prec1_t_test)
            losses_t_test.append(loss_t_test)
        weights_cl.append(anser_dict['weight_cl'])
        learning_rate.append(lr)

        # remember best prec@1 and save checkpoint
        if args.model in ('baseline', 'pi'):
            # Save according to test accuracy
            is_best = prec1_t_test > best_prec1
            if is_best:
                best_prec1 = prec1_t_test
            print("Best test precision: %.3f" % best_prec1)
            dict_checkpoint = {
                'epoch': epoch + 1,
                'loss_dict': loss_dict,
                'student_state_dict': student_model.state_dict(),
                'best_prec1': best_prec1,
                'acc1_tr': acc1_stu1_tr,
                'losses_tr': losses_stu1_tr,
                'acc1_val': acc1_t_val,
                'losses_val': losses_t_val,
                'acc1_test': acc1_t_test,
                'losses_test': losses_t_test,
                'weights_cl': weights_cl,
                'learning_rate': learning_rate,
            }
        elif args.model == 'mt':
            is_best = prec1_t_test > best_prec1
            if is_best:
                best_prec1 = prec1_t_test
            print("Best test precision: %.3f" % best_prec1)
            dict_checkpoint = {
                'epoch': epoch + 1,
                'loss_dict': loss_dict,
                'student_state_dict': student_model.state_dict(),
                'teacher_state_dict': teacher_model.state_dict(),
                'best_prec1': best_prec1,
                'acc1_1_tr': acc1_stu1_tr,
                'losses_1_tr': losses_stu1_tr,
                'acc1_t_tr': acc1_t_tr,
                'losses_t_tr': losses_t_tr,
                'acc1_val': acc1_t_val,
                'losses_val': losses_t_val,
                'acc1_test': acc1_t_test,
                'losses_test': losses_t_test,
                'weights_cl': weights_cl,
                'learning_rate': learning_rate,
            }
        else:
            is_best = prec1_t_test > best_prec1
            if is_best:
                best_prec1 = prec1_t_test
            print("Best test precision: %.3f" % best_prec1)
            dict_checkpoint = {
                'epoch': epoch + 1,
                'loss_dict': loss_dict,
                'teacher_state_dict': student_model.state_dict(),  # teacher model parameters
                'student1_state_dict': student1_model.state_dict(),  # student 1 model parameters
                'student2_state_dict': student2_model.state_dict(),  # student 2 model parameters
                'best_prec1': best_prec1,
                'acc1_1_tr': acc1_stu1_tr,
                'losses_1_tr': losses_stu1_tr,
                'acc1_2_tr': acc1_stu2_tr,
                'losses_2_tr': losses_stu2_tr,
                'acc1_t_tr': acc1_t_tr,
                'acc1_t_val': acc1_t_val,
                'loss_t_val': losses_t_val,
                'acc1_t_test': acc1_t_test,
                'loss_t_test': losses_t_test,
                'weights_cl': weights_cl,
                'learning_rate': learning_rate,
            }

        save_checkpoint(dict_checkpoint,
                        is_best,
                        args.arch.lower() + str(args.boundary),
                        dirname=ckpt_dir)

    # Plot and save the model loss curves
    save_loss(loss_dict['loss_class'],
              loss_dict['loss_cl'],
              loss_dict['loss_total'],
              name='model detail loss')
    save_loss(loss_dict['loss_class_1'],
              loss_dict['loss_cl_1'],
              loss_dict['loss_total_1'],
              name='Student1 loss')
    save_loss(loss_dict['loss_class_2'],
              loss_dict['loss_cl_2'],
              loss_dict['loss_total_2'],
              name='Student2 loss')
    save_loss(loss_dict['loss_total'],
              loss_dict['val_loss'],
              loss_dict['test_loss'],
              f_flag=True,
              name='train/val/test_loss')
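The 'mt' branch above deep-copies the student into a teacher, but the teacher update itself happens inside train_mt, which is not shown here. For reference, a minimal sketch of the exponential-moving-average update that Mean Teacher training typically applies after each optimizer step; the function name and alpha value are illustrative, not taken from this code.

def update_teacher_ema(teacher_model, student_model, alpha=0.999):
    # The teacher trails the student as an exponential moving average
    # of the student's weights.
    with torch.no_grad():
        for t_param, s_param in zip(teacher_model.parameters(),
                                    student_model.parameters()):
            t_param.mul_(alpha).add_(s_param, alpha=1 - alpha)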
Example #4
def main():
    global args, best_prec1, best_test_prec1
    global acc1_tr, losses_tr
    global losses_cl_tr
    global acc1_val, losses_val, losses_et_val
    global acc1_test, losses_test, losses_et_test
    global weights_cl
    args = parser.parse_args()
    print(args)
    # Identical hyper-parameters for svhn and the other datasets.
    drop_rate = 0.3
    widen_factor = 3

    # create model
    if args.arch == 'preresnet':
        print("Model: %s" % args.arch)
        model = preresnet_cifar.resnet(depth=32, num_classes=args.num_classes)
    elif args.arch == 'wideresnet':
        print("Model: %s" % args.arch)
        model = wideresnet.WideResNet(28,
                                      args.num_classes,
                                      widen_factor=widen_factor,
                                      dropRate=drop_rate,
                                      leakyRate=0.1)
    else:
        assert (False)

    if args.model == 'mt':
        import copy
        model_teacher = copy.deepcopy(model)
        model_teacher = torch.nn.DataParallel(model_teacher).cuda()

    model = torch.nn.DataParallel(model).cuda()
    print(model)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']

            model.load_state_dict(checkpoint['state_dict'])
            if args.model == 'mt':
                model_teacher.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.optim not in ('sgd', 'adam'):
        print('Not Implemented Optimizer')
        assert (False)

    ckpt_dir = args.ckpt + '_' + args.dataset + '_' + args.arch + '_' + args.model + '_' + args.optim
    ckpt_dir = ckpt_dir + '_e%d' % (args.epochs)
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    print(ckpt_dir)
    cudnn.benchmark = True

    # Data loading code
    if args.dataset == 'cifar10':
        dataloader = cifar.CIFAR10
        num_classes = 10
        data_dir = '/tmp/'

        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                         std=[0.2023, 0.1994, 0.2010])
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    elif args.dataset == 'cifar10_zca':
        dataloader = cifar_zca.CIFAR10
        num_classes = 10
        data_dir = 'cifar10_zca/cifar10_gcn_zca_v2.npz'

        # transforms are implemented inside the ZCA dataloader
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
        ])

    elif args.dataset == 'svhn':
        dataloader = svhn.SVHN
        num_classes = 10
        data_dir = '/tmp/'

        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                         std=[0.5, 0.5, 0.5])
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=2),
            transforms.ToTensor(),
            normalize,
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    labelset = dataloader(root=data_dir,
                          split='label',
                          download=True,
                          transform=transform_train,
                          boundary=args.boundary)
    unlabelset = dataloader(root=data_dir,
                            split='unlabel',
                            download=True,
                            transform=transform_train,
                            boundary=args.boundary)
    batch_size_label = args.batch_size // 2
    batch_size_unlabel = args.batch_size // 2
    if args.model == 'baseline':
        batch_size_label = args.batch_size

    label_loader = data.DataLoader(labelset,
                                   batch_size=batch_size_label,
                                   shuffle=True,
                                   num_workers=args.workers,
                                   pin_memory=True)
    label_iter = iter(label_loader)

    unlabel_loader = data.DataLoader(unlabelset,
                                     batch_size=batch_size_unlabel,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True)
    unlabel_iter = iter(unlabel_loader)

    print("Batch size (label): ", batch_size_label)
    print("Batch size (unlabel): ", batch_size_unlabel)

    validset = dataloader(root=data_dir,
                          split='valid',
                          download=True,
                          transform=transform_test,
                          boundary=args.boundary)
    val_loader = data.DataLoader(validset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 pin_memory=True)

    testset = dataloader(root=data_dir,
                         split='test',
                         download=True,
                         transform=transform_test)
    test_loader = data.DataLoader(testset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.workers,
                                  pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss(reduction='sum').cuda()
    criterion_mse = nn.MSELoss(reduction='sum').cuda()
    criterion_kl = nn.KLDivLoss(reduction='sum').cuda()
    criterion_l1 = nn.L1Loss(reduction='sum').cuda()

    criterions = (criterion, criterion_mse, criterion_kl, criterion_l1)

    if args.optim == 'adam':
        print('Using Adam optimizer')
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     betas=(0.9, 0.999),
                                     weight_decay=args.weight_decay)
    elif args.optim == 'sgd':
        print('Using SGD optimizer')
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        if args.optim == 'adam':
            print('Learning rate schedule for Adam')
            lr = adjust_learning_rate_adam(optimizer, epoch)
        elif args.optim == 'sgd':
            print('Learning rate schedule for SGD')
            lr = adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        if args.model == 'baseline':
            print('Supervised Training')
            # The baseline repeats 10 times since the training set is small.
            for i in range(10):
                prec1_tr, loss_tr = train_sup(label_loader, model, criterions,
                                              optimizer, epoch, args)
                weight_cl = 0.0
        elif args.model == 'pi':
            print('Pi model')
            prec1_tr, loss_tr, loss_cl_tr, weight_cl = train_pi(
                label_loader, unlabel_loader, model, criterions, optimizer,
                epoch, args)
        elif args.model == 'mt':
            print('Mean Teacher model')
            prec1_tr, loss_tr, loss_cl_tr, prec1_t_tr, weight_cl = train_mt(
                label_loader, unlabel_loader, model, model_teacher, criterions,
                optimizer, epoch, args)
        else:
            print("Not Implemented ", args.model)
            assert (False)

        # evaluate on validation set
        prec1_val, loss_val = validate(val_loader, model, criterions, args,
                                       'valid')
        prec1_test, loss_test = validate(test_loader, model, criterions, args,
                                         'test')
        if args.model == 'mt':
            prec1_t_val, loss_t_val = validate(val_loader, model_teacher,
                                               criterions, args, 'valid')
            prec1_t_test, loss_t_test = validate(test_loader, model_teacher,
                                                 criterions, args, 'test')

        # append values
        acc1_tr.append(prec1_tr)
        losses_tr.append(loss_tr)
        acc1_val.append(prec1_val)
        losses_val.append(loss_val)
        acc1_test.append(prec1_test)
        losses_test.append(loss_test)
        if args.model != 'baseline':
            losses_cl_tr.append(loss_cl_tr)
        if args.model == 'mt':
            acc1_t_tr.append(prec1_t_tr)
            acc1_t_val.append(prec1_t_val)
            acc1_t_test.append(prec1_t_test)
        weights_cl.append(weight_cl)
        learning_rate.append(lr)

        # remember best prec@1 and save checkpoint
        if args.model == 'mt':
            is_best = prec1_t_val > best_prec1
            if is_best:
                best_test_prec1_t = prec1_t_test
                best_test_prec1 = prec1_test
            print("Best test precision: %.3f" % best_test_prec1_t)
            best_prec1 = max(prec1_t_val, best_prec1)
            dict_checkpoint = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'best_test_prec1': best_test_prec1,
                'acc1_tr': acc1_tr,
                'losses_tr': losses_tr,
                'losses_cl_tr': losses_cl_tr,
                'acc1_val': acc1_val,
                'losses_val': losses_val,
                'acc1_test': acc1_test,
                'losses_test': losses_test,
                'acc1_t_tr': acc1_t_tr,
                'acc1_t_val': acc1_t_val,
                'acc1_t_test': acc1_t_test,
                'state_dict_teacher': model_teacher.state_dict(),
                'best_test_prec1_t': best_test_prec1_t,
                'weights_cl': weights_cl,
                'learning_rate': learning_rate,
            }

        else:
            is_best = prec1_val > best_prec1
            if is_best:
                best_test_prec1 = prec1_test
            print("Best test precision: %.3f" % best_test_prec1)
            best_prec1 = max(prec1_val, best_prec1)
            dict_checkpoint = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'best_test_prec1': best_test_prec1,
                'acc1_tr': acc1_tr,
                'losses_tr': losses_tr,
                'losses_cl_tr': losses_cl_tr,
                'acc1_val': acc1_val,
                'losses_val': losses_val,
                'acc1_test': acc1_test,
                'losses_test': losses_test,
                'weights_cl': weights_cl,
                'learning_rate': learning_rate,
            }

        save_checkpoint(dict_checkpoint,
                        is_best,
                        args.arch.lower() + str(args.boundary),
                        dirname=ckpt_dir)
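Both main functions rely on module-level state declared elsewhere in their scripts: the argparse parser, the metric lists named in the global statements, and the best-precision trackers. A minimal sketch of the preamble the second main plausibly assumes, with the names taken from its global statements and appends; the real argument definitions are not shown in this excerpt.

import argparse

parser = argparse.ArgumentParser()  # the actual arguments are defined elsewhere

best_prec1 = 0.0
best_test_prec1 = 0.0
acc1_tr, losses_tr, losses_cl_tr = [], [], []
acc1_val, losses_val, losses_et_val = [], [], []
acc1_test, losses_test, losses_et_test = [], [], []
acc1_t_tr, acc1_t_val, acc1_t_test = [], [], []
weights_cl, learning_rate = [], []

if __name__ == '__main__':
    main()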