示例#1
0
def train():
    """Train the model and report the EMA model's accuracy after each epoch.

    All hyper-parameters (``batchsize``, ``n_epoches``, ``lam_u``, ...) are
    read from module-level names.  The ``lambda_u`` entry handed to
    ``train_one_epoch`` is reset at the start of epoch ``e`` to
    ``e * lam_u / n_epoches``; ``lambda_u_once`` is that per-epoch share
    divided by the number of iterations per epoch (presumably the per-step
    increment applied inside ``train_one_epoch`` -- confirm there).
    """
    model, criteria_x, criteria_u = set_model()

    iters_per_epoch = n_imgs_per_epoch // batchsize
    dltrain_x, dltrain_u = get_train_loader(
        batchsize, iters_per_epoch, L=250, K=n_guesses)
    lb_guessor = LabelGuessor(model, T=temperature)
    mixuper = MixUp(mixup_alpha)

    ema = EMA(model, ema_alpha)
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    # Share of the unlabeled-loss weight added per epoch / per iteration.
    lam_u_per_epoch = float(lam_u) / n_epoches
    lam_u_per_iter = lam_u_per_epoch / iters_per_epoch

    train_args = {
        'model': model,
        'criteria_x': criteria_x,
        'criteria_u': criteria_u,
        'optim': optim,
        'ema': ema,
        'wd': 1 - weight_decay * lr,
        'dltrain_x': dltrain_x,
        'dltrain_u': dltrain_u,
        'lb_guessor': lb_guessor,
        'mixuper': mixuper,
        'lambda_u': 0,
        'lambda_u_once': lam_u_per_iter,
    }

    best_acc = -1
    print('start to train')
    for epoch in range(n_epoches):
        model.train()
        print('epoch: {}'.format(epoch))
        # Epoch-start value of the unlabeled-loss weight.
        train_args['lambda_u'] = epoch * lam_u_per_epoch
        train_one_epoch(**train_args)
        torch.cuda.empty_cache()

        acc = evaluate(ema)
        if best_acc < acc:
            best_acc = acc
        summary = ', '.join((
            'epoch: {}'.format(epoch),
            'acc: {:.4f}'.format(acc),
            'best_acc: {:.4f}'.format(best_acc),
        ))
        print(summary)
示例#2
0
def main():
    """Parse CLI options, build the FixMatch training setup and train.

    Each epoch runs ``train_one_epoch``, evaluates the EMA model on the
    validation loader, writes the resulting scalars through ``writer`` and
    logs the best top-1 accuracy seen so far together with its epoch.
    """
    parser = argparse.ArgumentParser(description=' FixMatch Training')
    parser.add_argument('--wresnet-k', default=2, type=int,
                        help='width factor of wide resnet')
    parser.add_argument('--wresnet-n', default=28, type=int,
                        help='depth of wide resnet')
    # The help text used to be a copy of the (disabled) --n-classes help.
    parser.add_argument('--dataset', type=str, default='CIFAR10',
                        help='name of the dataset')
    parser.add_argument('--n-labeled', type=int, default=40,
                        help='number of labeled samples for training')
    parser.add_argument('--n-epoches', type=int, default=1024,
                        help='number of training epoches')
    parser.add_argument('--batchsize', type=int, default=40,
                        help='train batch size of labeled samples')
    parser.add_argument('--mu', type=int, default=7,
                        help='factor of train batch size of unlabeled samples')
    parser.add_argument('--thr', type=float, default=0.95,
                        help='pseudo label threshold')
    parser.add_argument('--n-imgs-per-epoch', type=int, default=64 * 1024,
                        help='number of training images for each epoch')
    parser.add_argument('--lam-u', type=float, default=1.,
                        help='coefficient of unlabeled loss')
    parser.add_argument('--ema-alpha', type=float, default=0.999,
                        help='decay rate for ema module')
    parser.add_argument('--lr', type=float, default=0.03,
                        help='learning rate for training')
    parser.add_argument('--weight-decay', type=float, default=5e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum for optimizer')
    parser.add_argument('--seed', type=int, default=-1,
                        help='seed for random behaviors, no seed if negtive')
    parser.add_argument('--temperature', type=float, default=0.5,
                        help='temperature for loss function')
    args = parser.parse_args()

    logger, writer = setup_default_logging(args)
    logger.info(dict(args._get_kwargs()))

    # Global settings.  Only negative seeds mean "do not seed" (see --seed
    # help), so 0 is a valid seed as well.
    if args.seed >= 0:
        torch.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)

    # With the defaults: 1024 iterations per epoch, 1024 * 1024 in total.
    n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize
    n_iters_all = n_iters_per_epoch * args.n_epoches

    logger.info("***** Running training *****")
    logger.info(f"  Task = {args.dataset}@{args.n_labeled}")
    # Report the epoch count itself, not the iterations per epoch.
    logger.info(f"  Num Epochs = {args.n_epoches}")
    logger.info(f"  Batch size per GPU = {args.batchsize}")
    logger.info(f"  Total optimization steps = {n_iters_all}")

    model, criteria_x, criteria_u, criteria_z = set_model(args)
    logger.info("Total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1e6))

    dltrain_x, dltrain_u = get_train_loader(args.dataset,
                                            args.batchsize,
                                            args.mu,
                                            n_iters_per_epoch,
                                            L=args.n_labeled)
    dlval = get_val_loader(dataset=args.dataset, batch_size=64, num_workers=2)

    lb_guessor = LabelGuessor(thresh=args.thr)

    ema = EMA(model, args.ema_alpha)

    # Parameters whose name contains 'bn' are excluded from weight decay;
    # every other parameter uses the optimizer-wide weight_decay.
    wd_params, non_wd_params = [], []
    for name, param in model.named_parameters():
        if 'bn' in name:
            non_wd_params.append(param)
        else:
            wd_params.append(param)
    param_list = [
        {'params': wd_params},
        {'params': non_wd_params, 'weight_decay': 0},
    ]
    optim = torch.optim.SGD(param_list,
                            lr=args.lr,
                            weight_decay=args.weight_decay,
                            momentum=args.momentum,
                            nesterov=True)
    lr_schdlr = WarmupCosineLrScheduler(optim,
                                        max_iter=n_iters_all,
                                        warmup_iter=0)

    train_args = dict(model=model,
                      criteria_x=criteria_x,
                      criteria_u=criteria_u,
                      criteria_z=criteria_z,
                      optim=optim,
                      lr_schdlr=lr_schdlr,
                      ema=ema,
                      dltrain_x=dltrain_x,
                      dltrain_u=dltrain_u,
                      lb_guessor=lb_guessor,
                      lambda_u=args.lam_u,
                      n_iters=n_iters_per_epoch,
                      logger=logger)
    best_acc = -1
    best_epoch = 0
    logger.info('-----------start training--------------')
    try:
        for epoch in range(args.n_epoches):
            (train_loss, loss_x, loss_u, loss_u_real, loss_simclr,
             mask_mean) = train_one_epoch(epoch, **train_args)

            top1, top5, valid_loss = evaluate(ema, dlval, criteria_x)

            writer.add_scalars('train/1.loss', {
                'train': train_loss,
                'test': valid_loss
            }, epoch)
            writer.add_scalar('train/2.train_loss_x', loss_x, epoch)
            writer.add_scalar('train/3.train_loss_u', loss_u, epoch)
            writer.add_scalar('train/4.train_loss_u_real', loss_u_real, epoch)
            writer.add_scalar('train/4.train_loss_simclr', loss_simclr, epoch)
            writer.add_scalar('train/5.mask_mean', mask_mean, epoch)
            writer.add_scalars('test/1.test_acc', {
                'top1': top1,
                'top5': top5
            }, epoch)

            if best_acc < top1:
                best_acc = top1
                best_epoch = epoch

            logger.info(
                "Epoch {}. Top1: {:.4f}. Top5: {:.4f}. best_acc: {:.4f} in epoch{}"
                .format(epoch, top1, top5, best_acc, best_epoch))
    finally:
        # Close the writer even when training is interrupted by an error.
        writer.close()
示例#3
0
def train():
    """Train with optional label bootstrapping and log accuracy to a CSV.

    Configuration is read from the module-level ``args``.  At every epoch
    listed in the bootstrap schedule (and only when ``args.bootstrap > 1``)
    the per-class label budget is multiplied by ``args.bootstrap``,
    ``sort_unlabeled`` is called and the data loaders are rebuilt.  After
    each epoch the top-1 accuracy is computed, appended to the results
    table, written to ``<result_dir>/<save_name_pre>.accuracy.csv`` and
    printed.
    """
    n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize
    n_iters_all = n_iters_per_epoch * args.n_epochs
    epsilon = 0.000001

    model, criteria_x, criteria_u = set_model()
    lb_guessor = LabelGuessor(thresh=args.thr)
    ema = EMA(model, args.ema_alpha)

    # One-dimensional parameters are excluded from weight decay.
    wd_params, non_wd_params = [], []
    for param in model.parameters():
        if len(param.size()) == 1:
            non_wd_params.append(param)
        else:
            wd_params.append(param)
    param_list = [
        {'params': wd_params},
        {'params': non_wd_params, 'weight_decay': 0},
    ]
    optim = torch.optim.SGD(param_list,
                            lr=args.lr,
                            weight_decay=args.weight_decay,
                            momentum=args.momentum,
                            nesterov=True)
    lr_schdlr = WarmupCosineLrScheduler(optim,
                                        max_iter=n_iters_all,
                                        warmup_iter=0)

    dltrain_x, dltrain_u, dltrain_all = get_train_loader(
        args.batchsize, args.mu, args.mu_c, n_iters_per_epoch,
        L=args.n_labeled, seed=args.seed)
    train_args = dict(
        model=model,
        criteria_x=criteria_x,
        criteria_u=criteria_u,
        optim=optim,
        lr_schdlr=lr_schdlr,
        ema=ema,
        dltrain_x=dltrain_x,
        dltrain_u=dltrain_u,
        dltrain_all=dltrain_all,
        lb_guessor=lb_guessor,
    )
    n_labeled = int(args.n_labeled / args.n_classes)
    best_acc, top1 = -1, -1
    results = {'top 1 acc': [], 'best_acc': []}

    # Epochs at which bootstrapping happens.  Every entry is truncated to an
    # int: the epoch counter is an int, so a fractional entry (e.g. 7.5)
    # could never match and that bootstrap step would be silently skipped.
    if args.boot_schedule == 1:
        step = int(args.n_epochs / 3)
        b_schedule = [step, 2 * step]
    elif args.boot_schedule == 2:
        step = int(args.n_epochs / 4)
        b_schedule = [step, 2 * step, 3 * step]
    else:
        b_schedule = [int(args.n_epochs / 2), int(3 * args.n_epochs / 4)]

    for e in range(args.n_epochs):
        if args.bootstrap > 1 and (e in b_schedule):
            n_labeled *= args.bootstrap
            name = sort_unlabeled(ema, n_labeled)
            print("Bootstrap at epoch ", e," Name = ",name)
            # NOTE(review): the factor 10 looks like a hard-coded class
            # count (n_labeled is a per-class budget) -- confirm.
            dltrain_x, dltrain_u, dltrain_all = get_train_loader(
                args.batchsize, args.mu, args.mu_c, n_iters_per_epoch,
                L=10 * n_labeled, seed=99, name=name)
            # Only the loaders change; all other training arguments are kept.
            train_args.update(dltrain_x=dltrain_x,
                              dltrain_u=dltrain_u,
                              dltrain_all=dltrain_all)

        model.train()
        train_one_epoch(**train_args)
        torch.cuda.empty_cache()

        if args.test == 0 or args.lam_clr < epsilon:
            top1 = evaluate(ema) * 100
        elif args.test == 1:
            memory_data = utils.CIFAR10Pair(root='dataset', train=True, transform=utils.test_transform, download=False)
            memory_data_loader = DataLoader(memory_data, batch_size=args.batchsize, shuffle=False, num_workers=16, pin_memory=True)
            test_data = utils.CIFAR10Pair(root='dataset', train=False, transform=utils.test_transform, download=False)
            test_data_loader = DataLoader(test_data, batch_size=args.batchsize, shuffle=False, num_workers=16, pin_memory=True)
            c = len(memory_data.classes)
            top1 = test(model, memory_data_loader, test_data_loader, c, e)
        # For any other args.test value top1 keeps its previous value.

        best_acc = top1 if best_acc < top1 else best_acc

        results['top 1 acc'].append('{:.4f}'.format(top1))
        results['best_acc'].append('{:.4f}'.format(best_acc))
        # The whole table is rewritten every epoch so the file is always
        # current.
        data_frame = pd.DataFrame(data=results)
        data_frame.to_csv(result_dir + '/' + save_name_pre + '.accuracy.csv', index_label='epoch')

        log_msg = [
            'epoch: {}'.format(e + 1),
            'top 1 acc: {:.4f}'.format(top1),
            'best_acc: {:.4f}'.format(best_acc)]
        print(', '.join(log_msg))
def main():
    """Parse CLI options and run FixMatch training with a SimCLR loss term.

    Each epoch first evaluates a linear classifier on the EMA model via
    ``evaluate_linear_Clf`` and logs it, then runs one semi-supervised
    training epoch with ``train_one_epoch``, evaluates the EMA model on the
    validation loader, writes the scalars through ``writer`` and logs the
    best top-1 accuracy seen so far together with its epoch.
    """
    parser = argparse.ArgumentParser(description=' FixMatch Training')
    parser.add_argument('--wresnet-k', default=2, type=int,
                        help='width factor of wide resnet')
    parser.add_argument('--wresnet-n', default=28, type=int,
                        help='depth of wide resnet')
    # The help text used to be a copy of the (disabled) --n-classes help.
    parser.add_argument('--dataset', type=str, default='CIFAR10',
                        help='name of the dataset')
    parser.add_argument('--n-labeled', type=int, default=40,
                        help='number of labeled samples for training')
    parser.add_argument('--n-epoches', type=int, default=1024,
                        help='number of training epoches')
    parser.add_argument('--batchsize', type=int, default=40,
                        help='train batch size of labeled samples')
    parser.add_argument('--mu', type=int, default=7,
                        help='factor of train batch size of unlabeled samples')
    parser.add_argument('--thr', type=float, default=0.95,
                        help='pseudo label threshold')
    parser.add_argument('--n-imgs-per-epoch', type=int, default=64 * 1024,
                        help='number of training images for each epoch')
    parser.add_argument('--lam-u', type=float, default=1.,
                        help='coefficient of unlabeled loss')
    parser.add_argument('--lam-s', type=float, default=0.2,
                        help='coefficient of unlabeled loss SimCLR')
    parser.add_argument('--ema-alpha', type=float, default=0.999,
                        help='decay rate for ema module')
    parser.add_argument('--lr', type=float, default=0.03,
                        help='learning rate for training')
    parser.add_argument('--weight-decay', type=float, default=5e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum for optimizer')
    parser.add_argument('--seed', type=int, default=-1,
                        help='seed for random behaviors, no seed if negtive')
    parser.add_argument('--temperature', type=float, default=0.5,
                        help='temperature for loss function')
    args = parser.parse_args()

    logger, writer = setup_default_logging(args)
    logger.info(dict(args._get_kwargs()))

    # Global settings.  Only negative seeds mean "do not seed" (see --seed
    # help), so 0 is a valid seed as well.
    if args.seed >= 0:
        torch.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)

    # With the defaults: 1024 iterations per epoch, 1024 * 1024 in total.
    n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize
    n_iters_all = n_iters_per_epoch * args.n_epoches

    logger.info("***** Running training *****")
    logger.info(f"  Task = {args.dataset}@{args.n_labeled}")
    # Report the epoch count itself, not the iterations per epoch.
    logger.info(f"  Num Epochs = {args.n_epoches}")
    logger.info(f"  Batch size per GPU = {args.batchsize}")
    logger.info(f"  Total optimization steps = {n_iters_all}")

    model, criteria_x, criteria_u, criteria_z = set_model(args)

    logger.info("Total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1e6))

    dltrain_x, dltrain_u, dltrain_f = get_train_loader_mix(args.dataset,
                                                           args.batchsize,
                                                           args.mu,
                                                           n_iters_per_epoch,
                                                           L=args.n_labeled)
    dlval = get_val_loader(dataset=args.dataset, batch_size=64, num_workers=2)

    lb_guessor = LabelGuessor(thresh=args.thr)

    ema = EMA(model, args.ema_alpha)

    # Parameters whose name contains 'bn' are excluded from weight decay;
    # every other parameter uses the optimizer-wide weight_decay.
    wd_params, non_wd_params = [], []
    for name, param in model.named_parameters():
        if 'bn' in name:
            non_wd_params.append(param)
        else:
            wd_params.append(param)

    param_list = [
        {'params': wd_params},
        {'params': non_wd_params, 'weight_decay': 0},
    ]

    optim_fix = torch.optim.SGD(param_list,
                                lr=args.lr,
                                weight_decay=args.weight_decay,
                                momentum=args.momentum,
                                nesterov=True)
    lr_schdlr_fix = WarmupCosineLrScheduler(optim_fix,
                                            max_iter=n_iters_all,
                                            warmup_iter=0)

    train_args = dict(model=model,
                      criteria_x=criteria_x,
                      criteria_u=criteria_u,
                      criteria_z=criteria_z,
                      optim=optim_fix,
                      lr_schdlr=lr_schdlr_fix,
                      ema=ema,
                      dltrain_x=dltrain_x,
                      dltrain_u=dltrain_u,
                      dltrain_f=dltrain_f,
                      lb_guessor=lb_guessor,
                      lambda_u=args.lam_u,
                      lambda_s=args.lam_s,
                      n_iters=n_iters_per_epoch,
                      logger=logger,
                      bt=args.batchsize,
                      mu=args.mu)

    # NOTE: separate optimizers / schedulers for a purely unsupervised
    # SimCLR or IIC pre-training phase used to be configured here; that
    # phase is currently disabled (see the unreachable branch in the loop).

    best_acc = -1
    best_epoch = 0
    logger.info('-----------start training--------------')

    try:
        for epoch in range(args.n_epoches):
            # Record the accuracy of a linear classifier on the h space
            # (backbone output) before this epoch's training step.
            top1, top5, valid_loss = evaluate_linear_Clf(
                ema, dltrain_x, dlval, criteria_x)
            writer.add_scalars('test/1.test_linear_acc', {
                'top1': top1,
                'top5': top5
            }, epoch)

            logger.info(
                "Epoch {}. on h space Top1: {:.4f}. Top5: {:.4f}.".format(
                    epoch, top1, top5))

            if epoch < -500:
                # Unsupervised training phase (SimCLR / IIC).  Disabled:
                # epoch is never negative, so this branch is unreachable.
                # NOTE(review): if it is re-enabled, train_loss must be set
                # here, otherwise the 'train/1.loss' logging below raises
                # NameError.
                top1, top5, valid_loss = evaluate_linear_Clf(
                    ema, dltrain_x, dlval, criteria_x)
            else:
                # Semi-supervised training.
                (train_loss, loss_x, loss_u, loss_u_real, mask_mean,
                 loss_simclr) = train_one_epoch(epoch, **train_args)
                top1, top5, valid_loss = evaluate(ema, dlval, criteria_x)

                writer.add_scalar('train/4.train_loss_simclr', loss_simclr,
                                  epoch)
                writer.add_scalar('train/2.train_loss_x', loss_x, epoch)
                writer.add_scalar('train/3.train_loss_u', loss_u, epoch)
                writer.add_scalar('train/4.train_loss_u_real', loss_u_real,
                                  epoch)
                writer.add_scalar('train/5.mask_mean', mask_mean, epoch)

            writer.add_scalars('train/1.loss', {
                'train': train_loss,
                'test': valid_loss
            }, epoch)
            writer.add_scalars('test/1.test_acc', {
                'top1': top1,
                'top5': top5
            }, epoch)

            if best_acc < top1:
                best_acc = top1
                best_epoch = epoch

            logger.info(
                "Epoch {}. Top1: {:.4f}. Top5: {:.4f}. best_acc: {:.4f} in epoch{}"
                .format(epoch, top1, top5, best_acc, best_epoch))
    finally:
        # Close the writer even when training is interrupted by an error.
        writer.close()
示例#5
0
    def __init__(self, pwd):
        """Build model, data loaders, optimizer, scheduler, EMA and logger.

        All hyper-parameters come from the module-level ``cfg``.  When
        ``cfg.resume`` is set, model / optimizer / scheduler state and the
        bookkeeping fields are restored from ``cfg.resume_model`` before the
        EMA wrapper is created.

        Args:
            pwd: forwarded unchanged to ``logger``.
        """
        net = WideResnet(cfg.n_classes,
                         k=cfg.wresnet_k,
                         n=cfg.wresnet_n,
                         batchsize=cfg.batch_size)
        self.model = net
        self.model = self.model.cuda()

        self.unlabeled_trainloader, self.val_loader = loading_data()
        self.labeled_trainloader = update_loading(epoch=0)

        # One-dimensional parameters are excluded from weight decay.
        decayed, undecayed = [], []
        for p in self.model.parameters():
            bucket = undecayed if len(p.size()) == 1 else decayed
            bucket.append(p)
        groups = [
            {'params': decayed},
            {'params': undecayed, 'weight_decay': 0},
        ]
        self.optimizer = torch.optim.SGD(groups,
                                         lr=cfg.lr,
                                         weight_decay=cfg.weight_decay,
                                         momentum=cfg.momentum,
                                         nesterov=True)

        self.n_iters_per_epoch = cfg.n_imgs_per_epoch // cfg.batch_size
        total_iters = self.n_iters_per_epoch * cfg.n_epoches
        self.lr_schdlr = WarmupCosineLrScheduler(self.optimizer,
                                                 max_iter=total_iters,
                                                 warmup_iter=0)

        self.lb_guessor = LabelGuessor(args=cfg)

        self.train_record = dict(best_acc1=0,
                                 best_model_name='',
                                 last_model_name='')
        self.cross_entropy = nn.CrossEntropyLoss().cuda()

        # Defaults for a fresh run; overwritten below when resuming.
        self.i_tb = 0
        self.epoch = 0
        self.exp_name = cfg.exp_name
        self.exp_path = cfg.exp_path

        if cfg.resume:
            print('Loaded resume weights for WideResnet')

            ckpt = torch.load(cfg.resume_model)
            self.model.load_state_dict(ckpt['net'])
            self.optimizer.load_state_dict(ckpt['optimizer'])
            self.lr_schdlr.load_state_dict(ckpt['scheduler'])
            # Continue with the epoch after the one stored in the checkpoint.
            self.epoch = ckpt['epoch'] + 1
            self.i_tb = ckpt['i_tb']
            self.train_record = ckpt['train_record']
            self.exp_path = ckpt['exp_path']
            self.exp_name = ckpt['exp_name']

        self.ema = EMA(self.model, cfg.ema_alpha)
        # NOTE(review): the logger is built from cfg.exp_path / cfg.exp_name,
        # not from the values restored from the checkpoint -- confirm that
        # this is intended.
        ignored = ['exp', 'dataset', 'pretrained', 'pre_trained']
        self.writer, self.log_txt = logger(cfg.exp_path, cfg.exp_name, pwd,
                                           ignored)
示例#6
0
class Trainer():
    """Semi-supervised trainer for a WideResnet driven by the global ``cfg``.

    ``forward`` runs the epoch loop, ``train`` performs one training epoch
    and ``evaluate`` measures validation accuracy of the EMA model and hands
    the result to ``update_model``.
    """

    def __init__(self, pwd):
        """Build model, data loaders, optimizer, scheduler, EMA and logger.

        When ``cfg.resume`` is set, model / optimizer / scheduler state and
        the bookkeeping fields are restored from ``cfg.resume_model`` before
        the EMA wrapper is created.

        Args:
            pwd: forwarded unchanged to ``logger``.
        """
        self.model = WideResnet(cfg.n_classes,
                                k=cfg.wresnet_k,
                                n=cfg.wresnet_n,
                                batchsize=cfg.batch_size)
        self.model = self.model.cuda()

        self.unlabeled_trainloader, self.val_loader = loading_data()
        self.labeled_trainloader = update_loading(epoch=0)

        # One-dimensional parameters are excluded from weight decay.
        wd_params, non_wd_params = [], []
        for param in self.model.parameters():
            if len(param.size()) == 1:
                non_wd_params.append(param)
            else:
                wd_params.append(param)
        param_list = [
            {'params': wd_params},
            {'params': non_wd_params, 'weight_decay': 0},
        ]
        self.optimizer = torch.optim.SGD(param_list,
                                         lr=cfg.lr,
                                         weight_decay=cfg.weight_decay,
                                         momentum=cfg.momentum,
                                         nesterov=True)

        self.n_iters_per_epoch = cfg.n_imgs_per_epoch // cfg.batch_size

        self.lr_schdlr = WarmupCosineLrScheduler(
            self.optimizer,
            max_iter=self.n_iters_per_epoch * cfg.n_epoches,
            warmup_iter=0)

        self.lb_guessor = LabelGuessor(args=cfg)

        self.train_record = {
            'best_acc1': 0,
            'best_model_name': '',
            'last_model_name': ''
        }
        self.cross_entropy = nn.CrossEntropyLoss().cuda()
        # Defaults for a fresh run; overwritten below when resuming.
        self.i_tb = 0
        self.epoch = 0
        self.exp_name = cfg.exp_name
        self.exp_path = cfg.exp_path
        if cfg.resume:
            print('Loaded resume weights for WideResnet')

            latest_state = torch.load(cfg.resume_model)
            self.model.load_state_dict(latest_state['net'])
            self.optimizer.load_state_dict(latest_state['optimizer'])
            self.lr_schdlr.load_state_dict(latest_state['scheduler'])
            # Continue with the epoch after the one stored in the checkpoint.
            self.epoch = latest_state['epoch'] + 1
            self.i_tb = latest_state['i_tb']
            self.train_record = latest_state['train_record']
            self.exp_path = latest_state['exp_path']
            self.exp_name = latest_state['exp_name']

        self.ema = EMA(self.model, cfg.ema_alpha)
        # NOTE(review): the logger is built from cfg.exp_path / cfg.exp_name,
        # not from the values restored from the checkpoint -- confirm that
        # this is intended.
        self.writer, self.log_txt = logger(
            cfg.exp_path, cfg.exp_name, pwd,
            ['exp', 'dataset', 'pretrained', 'pre_trained'])

    def forward(self):
        """Run training and evaluation from ``self.epoch`` to the last epoch."""
        print('start to train')
        for epoch in range(self.epoch, cfg.n_epoches):
            self.epoch = epoch

            print(('=' * 50 + 'epoch: {}' + '=' * 50).format(self.epoch + 1))
            self.train()
            torch.cuda.empty_cache()
            self.evaluate(self.ema)

    def train(self):
        """Run one training epoch of ``self.n_iters_per_epoch`` iterations.

        The loss is the labeled cross-entropy, plus ``cfg.lam_u`` times the
        cross-entropy on the selected pseudo-labeled samples (from epoch 2
        on), plus 0.1 times the triplet MI loss (from epoch 20 on).
        Running averages are printed every 256 iterations and two
        pseudo-label accuracy scalars are written at the end of the epoch.
        """
        if self.epoch > cfg.start_add_samples_epoch:
            # Rebuild the labeled loader from the label generator's data.
            indexs, pre_targets = self.lb_guessor.label_generator.new_data()
            self.labeled_trainloader = update_loading(
                copy.deepcopy(indexs), copy.deepcopy(pre_targets), self.epoch)
            self.lb_guessor.init_for_add_sample(self.epoch,
                                                cfg.start_add_samples_epoch)

        self.model.train()
        Loss, Loss_L, Loss_U, Loss_U_Real, Loss_MI = (
            AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(),
            AverageMeter())
        Correct_Num, Valid_Num = AverageMeter(), AverageMeter()

        st = time.time()
        l_set = iter(self.labeled_trainloader)
        u_set = iter(self.unlabeled_trainloader)

        for it in range(self.n_iters_per_epoch):
            # Only the weak view of the labeled batch is used below, so the
            # other two views are neither named nor copied to the GPU.
            (_, img_l_weak, _), lbs_l = next(l_set)

            (img_u, img_u_weak,
             img_u_strong), lbs_u_real, index_u = next(u_set)

            img_l_weak, lbs_l = img_l_weak.cuda(), lbs_l.cuda()
            img_u = img_u.cuda()
            img_u_weak = img_u_weak.cuda()
            img_u_strong = img_u_strong.cuda()

            # Pseudo-labels plus the selector of unlabeled samples they
            # belong to (valid_u indexes the strong-view predictions below).
            lbs_u, valid_u = self.lb_guessor(self.model, img_l_weak,
                                             img_u_weak, lbs_l, index_u)

            n_u = img_u_strong.size(0)

            # Single forward pass over labeled and all unlabeled views.
            img_cat = torch.cat([img_l_weak, img_u, img_u_weak, img_u_strong],
                                dim=0).detach()

            _, __, pred_l, pred_u = self.model(img_cat)

            # pred_u holds the three unlabeled views back to back:
            # original, weak, strong.
            pred_u_o = pred_u[:n_u]
            pred_u_w = pred_u[n_u:2 * n_u]
            pred_u_s = pred_u[2 * n_u:]

            # ============ cross-entropy loss for labeled data ============
            loss_l = self.cross_entropy(pred_l, lbs_l)

            # ============ T-MI loss for unlabeled data ============
            if self.epoch >= 20:
                T_MI_loss = Triplet_MI_loss(pred_u_o, pred_u_w, pred_u_s)
            else:
                T_MI_loss = torch.tensor(0)

            # ============ cross-entropy loss for unlabeled data ============
            if lbs_u.size(0) > 0 and self.epoch >= 2:
                pred_u_s = pred_u_s[valid_u]
                loss_u = self.cross_entropy(pred_u_s, lbs_u)

                # Monitoring only: compare pseudo-labels with the real
                # labels of the unlabeled samples.
                with torch.no_grad():
                    lbs_u_real = lbs_u_real[valid_u].cuda()
                    valid_num = lbs_u_real.size(0)
                    corr_lb = (lbs_u_real == lbs_u)
                    loss_u_real = F.cross_entropy(pred_u_s, lbs_u_real)
            else:
                loss_u = torch.tensor(0)
                loss_u_real = torch.tensor(0)
                corr_lb = torch.tensor(0)
                valid_num = 0

            loss = loss_l + cfg.lam_u * loss_u + 0.1 * T_MI_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.ema.update_params()
            self.lr_schdlr.step()

            Loss.update(loss.item())
            Loss_L.update(loss_l.item())
            Loss_U.update(loss_u.item())
            Loss_U_Real.update(loss_u_real.item())
            Loss_MI.update(T_MI_loss.item())
            Correct_Num.update(corr_lb.sum().item())
            Valid_Num.update(valid_num)

            if (it + 1) % 256 == 0:
                self.i_tb += 1
                self.writer.add_scalar('loss_u', Loss_U.avg, self.i_tb)
                self.writer.add_scalar('loss_MI', Loss_MI.avg, self.i_tb)
                ed = time.time()
                t = ed - st
                # Mean learning rate over all parameter groups.
                lr_log = [pg['lr'] for pg in self.optimizer.param_groups]
                lr_log = sum(lr_log) / len(lr_log)
                msg = ', '.join([
                    '     [iter: {}',
                    'loss: {:.3f}',
                    'loss_l: {:.4f}',
                    'loss_u: {:.4f}',
                    'loss_u_real: {:.4f}',
                    'loss_MI: {:.4f}',
                    'correct: {}/{}',
                    'lr: {:.4f}',
                    'time: {:.2f}]',
                ]).format(it + 1, Loss.avg, Loss_L.avg, Loss_U.avg,
                          Loss_U_Real.avg, Loss_MI.avg, int(Correct_Num.avg),
                          int(Valid_Num.avg), lr_log, t)
                st = ed
                print(msg)

        self.ema.update_buffer()
        self.writer.add_scalar(
            'acc_overall', Correct_Num.sum / (cfg.n_imgs_per_epoch * cfg.mu),
            self.epoch + 1)
        # The small constant avoids division by zero when no pseudo-label
        # was selected during the epoch.
        self.writer.add_scalar('acc_in_labeled',
                               Correct_Num.sum / (Valid_Num.sum + 1e-10),
                               self.epoch + 1)

    def evaluate(self, ema):
        """Measure validation accuracy of the EMA model and checkpoint it.

        ``ema.apply_shadow()`` is called before and ``ema.restore()`` after
        the evaluation (presumably swapping the averaged weights in and
        out).  The restore now runs in a ``finally`` clause, so a failure
        during evaluation or checkpointing does not skip it.

        Args:
            ema: EMA wrapper whose ``model`` is evaluated.
        """
        ema.apply_shadow()
        try:
            ema.model.eval()
            ema.model.cuda()

            matches = []
            for ims, lbs in self.val_loader:
                ims = ims.cuda()
                lbs = lbs.cuda()
                with torch.no_grad():
                    __, preds = ema.model(ims, mode='val')
                    scores = torch.softmax(preds, dim=1)
                    _, preds = torch.max(scores, dim=1)
                    match = lbs == preds
                    matches.append(match)
            matches = torch.cat(matches, dim=0).float()
            acc = torch.mean(matches)

            self.writer.add_scalar('val_acc', acc, self.epoch)
            self.train_record = update_model(ema.model, self.optimizer,
                                             self.lr_schdlr, self.epoch,
                                             self.i_tb, self.exp_path,
                                             self.exp_name, acc,
                                             self.train_record)
            print_summary(cfg.exp_name, acc, self.train_record)
        finally:
            ema.restore()
示例#7
0
def train():
    """Run the full training loop configured by the module-level ``args``.

    Builds the model and loss criteria, the labeled/unlabeled train loaders,
    the label guessor, the EMA wrapper, an SGD optimizer with Nesterov
    momentum and a warmup-cosine learning-rate schedule spanning every
    iteration of the run.  Each epoch then calls ``train_one_epoch``,
    evaluates through ``evaluate(ema)`` and prints the current and best
    accuracy.  After the last epoch ``sort_unlabeled(ema)`` is called.

    The per-epoch summary line now reports the same 1-based epoch number as
    the header printed at the start of the epoch (it used to be 0-based,
    so the two lines disagreed).
    """
    n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize
    n_iters_all = n_iters_per_epoch * args.n_epochs

    model, criteria_x, criteria_u = set_model()

    dltrain_x, dltrain_u = get_train_loader(args.batchsize,
                                            args.mu,
                                            n_iters_per_epoch,
                                            L=args.n_labeled,
                                            seed=args.seed)
    lb_guessor = LabelGuessor(thresh=args.thr)

    ema = EMA(model, args.ema_alpha)

    # One-dimensional parameters (presumably biases and normalisation
    # scales/shifts -- confirm against the model) are excluded from weight
    # decay; every other parameter uses the optimizer-wide setting.
    wd_params, non_wd_params = [], []
    for param in model.parameters():
        if param.dim() == 1:
            non_wd_params.append(param)
        else:
            wd_params.append(param)
    param_list = [
        {'params': wd_params},
        {'params': non_wd_params, 'weight_decay': 0},
    ]
    optim = torch.optim.SGD(param_list,
                            lr=args.lr,
                            weight_decay=args.weight_decay,
                            momentum=args.momentum,
                            nesterov=True)
    lr_schdlr = WarmupCosineLrScheduler(optim,
                                        max_iter=n_iters_all,
                                        warmup_iter=0)

    train_args = dict(
        model=model,
        criteria_x=criteria_x,
        criteria_u=criteria_u,
        optim=optim,
        lr_schdlr=lr_schdlr,
        ema=ema,
        dltrain_x=dltrain_x,
        dltrain_u=dltrain_u,
        lb_guessor=lb_guessor,
        lambda_u=args.lam_u,
        lambda_c=args.lam_c,
        n_iters=n_iters_per_epoch,
    )
    best_acc = -1
    print('start to train')
    for e in range(args.n_epochs):
        epoch_no = e + 1  # 1-based number used in every message below
        model.train()
        print('epoch: {}'.format(epoch_no))
        train_one_epoch(**train_args)
        torch.cuda.empty_cache()

        acc = evaluate(ema)
        # Keeps the previous best on ties, exactly like the former
        # conditional expression.
        best_acc = max(best_acc, acc)
        log_msg = [
            'epoch: {}'.format(epoch_no),
            'acc: {:.4f}'.format(acc),
            'best_acc: {:.4f}'.format(best_acc),
        ]
        print(', '.join(log_msg))

    sort_unlabeled(ema)