Example #1
0
def _build_transforms(args):
    """Return ``(train_transform, test_transform)`` for ``args.aug``/``args.dataset``.

    Every supported augmentation yields a matching evaluation transform, so
    callers can always build a test loader.

    Raises:
        NotImplementedError: if the augmentation/dataset combination is not
            one of the supported branches below.
    """
    crop_padding = 32
    image_size = 224
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    cifar_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                           (0.2023, 0.1994, 0.2010))

    # TODO: Currently follow CMC
    # Shared by every branch that trains on `image_size` crops with the
    # `normalize` statistics above.
    large_test_transform = transforms.Compose([
        transforms.Resize(image_size + crop_padding),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ])

    if args.aug == 'NULL' and args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        return train_transform, large_test_transform

    if args.aug == 'CJ':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        return train_transform, large_test_transform

    if args.aug == 'simple' and args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            get_color_distortion(1.0),
            transforms.ToTensor(),
            normalize,
        ])
        return train_transform, large_test_transform

    if args.aug == 'simple' and args.dataset == 'cifar':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size=32),
            transforms.RandomHorizontalFlip(p=0.5),
            get_color_distortion(0.5),
            transforms.ToTensor(),
            cifar_normalize,
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            cifar_normalize,
        ])
        return train_transform, test_transform

    # `NotImplemented` is a non-callable constant; the exception class is
    # `NotImplementedError`.
    raise NotImplementedError(
        'augmentation not supported: {}'.format(args.aug))


def _build_loaders(args, train_transform, test_transform):
    """Create the datasets and loaders for ``args.dataset``.

    Returns:
        ``(train_dataset, train_loader, test_loader)``.

    Raises:
        NotImplementedError: if ``args.dataset`` is neither ``'imagenet'``
            nor ``'cifar'``.
    """
    if args.dataset == "imagenet":
        data_folder = os.path.join(args.data_folder, 'train')
        val_folder = os.path.join(args.data_folder, 'val')

        train_dataset = ImageFolderInstance(data_folder,
                                            transform=train_transform,
                                            two_crop=args.moco)
        print(len(train_dataset))
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            sampler=None)

        # ImageFolder takes `transform` (singular); `transforms` is not an
        # accepted keyword.
        test_dataset = datasets.ImageFolder(val_folder,
                                            transform=test_transform)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=256,
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True)

    elif args.dataset == 'cifar':
        # cifar-10 dataset
        if args.contrastive_model == 'simclr':
            train_dataset = CIFAR10Instance_double(root='./data',
                                                   train=True,
                                                   download=True,
                                                   transform=train_transform,
                                                   double=True)
        else:
            train_dataset = CIFAR10Instance(root='./data',
                                            train=True,
                                            download=True,
                                            transform=train_transform)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            sampler=None,
            drop_last=True)

        test_dataset = CIFAR10Instance(root='./data',
                                       train=False,
                                       download=True,
                                       transform=test_transform)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=100,
            shuffle=False,
            num_workers=args.num_workers)

    else:
        # Fail here instead of with a NameError on `train_dataset` later.
        raise NotImplementedError(
            'dataset not supported: {}'.format(args.dataset))

    return train_dataset, train_loader, test_loader


def _build_models(args):
    """Instantiate the encoder and, for MoCo, a second encoder of the same kind.

    Returns:
        ``(model, model_ema)``; ``model_ema`` is ``None`` unless
        ``args.contrastive_model == 'moco'``.

    Raises:
        NotImplementedError: if ``args.model`` is not a known architecture.
    """
    factories = {
        'resnet50': lambda: InsResNet50(),
        'resnet50x2': lambda: InsResNet50(width=2),
        'resnet50x4': lambda: InsResNet50(width=4),
        'resnet50_cifar': lambda: InsResNet50_cifar(),
    }
    if args.model not in factories:
        raise NotImplementedError('model not supported {}'.format(args.model))

    make_model = factories[args.model]
    model = make_model()
    model_ema = make_model() if args.contrastive_model == 'moco' else None
    return model, model_ema


def _checkpoint_state(args, epoch, model, model_ema, contrast, optimizer):
    """Build the dict written to disk for one checkpoint.

    The ``'contrast'`` entry is only present when a contrast memory exists
    (it is ``None`` for simclr); the resume path in ``main`` reads it back.
    """
    state = {
        'opt': args,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    if contrast is not None:
        state['contrast'] = contrast.state_dict()
    if args.contrastive_model == 'moco':
        state['model_ema'] = model_ema.state_dict()
    if args.amp:
        state['amp'] = amp.state_dict()
    return state


def main():
    """Train a contrastive model (MoCo, SimCLR or InsDis) end to end.

    Parses the command-line options, builds the data pipeline, model,
    contrast memory, criterion and optimizer, optionally resumes from
    ``args.resume``, then runs the epoch loop. Each epoch logs loss/prob/lr
    to tensorboard, every second epoch runs a kNN evaluation (cifar only),
    and a checkpoint is written to ``current.pth`` plus a numbered file every
    ``args.save_freq`` epochs.
    """
    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    is_moco = args.contrastive_model == 'moco'

    # set the data loader
    train_transform, test_transform = _build_transforms(args)
    train_dataset, train_loader, test_loader = _build_loaders(
        args, train_transform, test_transform)

    # create model and optimizer
    n_data = len(train_dataset)
    model, model_ema = _build_models(args)

    # copy weights from `model' to `model_ema'
    if is_moco:
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    if is_moco:
        contrast = MemoryMoCo(128, n_data, args.nce_k, args.nce_t,
                              args.softmax).cuda(args.gpu)
    elif args.contrastive_model == 'simclr':
        contrast = None
    else:
        contrast = MemoryInsDis(128, n_data, args.nce_k, args.nce_t,
                                args.nce_m, args.softmax).cuda(args.gpu)

    if args.softmax:
        criterion = NCESoftmaxLoss()
    elif args.contrastive_model == 'simclr':
        criterion = BatchCriterion(1, args.nce_t, args.batch_size)
    else:
        criterion = NCECriterion(n_data)
    criterion = criterion.cuda(args.gpu)

    # NOTE(review): the models go to the current CUDA device while contrast
    # and criterion go to `args.gpu`; confirm these agree when `args.gpu` is
    # not the default device.
    model = model.cuda()
    if is_moco:
        model_ema = model_ema.cuda()

    # Exclude BN and bias if needed
    weight_decay = args.weight_decay
    if weight_decay and args.filter_weight_decay:
        parameters = add_weight_decay(model, weight_decay,
                                      args.filter_weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    optimizer = torch.optim.SGD(parameters,
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=weight_decay)
    cudnn.benchmark = True

    if args.amp:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level)
        if is_moco:
            optimizer_ema = torch.optim.SGD(model_ema.parameters(),
                                            lr=0,
                                            momentum=0,
                                            weight_decay=0)
            model_ema, optimizer_ema = amp.initialize(model_ema,
                                                      optimizer_ema,
                                                      opt_level=args.opt_level)

    if args.LARS:
        optimizer = LARS(optimizer=optimizer, eps=1e-8, trust_coef=0.001)

    # optionally resume from a checkpoint
    args.start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if contrast is not None:
                # Older checkpoints were written without the contrast memory;
                # tolerate that instead of failing with a KeyError.
                if 'contrast' in checkpoint:
                    contrast.load_state_dict(checkpoint['contrast'])
                else:
                    print("=> checkpoint has no 'contrast' state, "
                          "keeping freshly initialised memory")
            if is_moco:
                model_ema.load_state_dict(checkpoint['model_ema'])

            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])

            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    test_epoch = 2
    for epoch in range(args.start_epoch, args.epochs + 1):

        print("==> training...")

        time1 = time.time()
        if is_moco:
            loss, prob = train_moco(epoch, train_loader, model, model_ema,
                                    contrast, criterion, optimizer, args)
        elif args.contrastive_model == 'simclr':
            print("Train using simclr")
            loss, prob = train_simclr(epoch, train_loader, model, criterion,
                                      optimizer, args)
        else:
            print("Train using InsDis")
            loss, prob = train_ins(epoch, train_loader, model, contrast,
                                   criterion, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('ins_loss', loss, epoch)
        logger.log_value('ins_prob', prob, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'],
                         epoch)

        if epoch % test_epoch == 0:
            model.eval()
            if is_moco:
                model_ema.eval()

            print('----------Evaluation---------')
            start = time.time()

            acc = None
            if args.dataset == 'cifar':
                acc = kNN(epoch,
                          model,
                          train_loader,
                          test_loader,
                          200,
                          args.nce_t,
                          n_data,
                          low_dim=128,
                          memory_bank=None)

            print("Evaluation Time: '{}'s".format(time.time() - start))

            # Only report accuracy when an evaluation actually ran; kNN is
            # implemented for cifar only.
            if acc is not None:
                logger.log_value('Test accuracy', acc, epoch)
                print('[Epoch]: {}'.format(epoch))
                print('accuracy: {}%)'.format(acc))

        # saving the model: always refresh `current.pth`, and additionally
        # keep a numbered snapshot every `save_freq` epochs.
        print('==> Saving...')
        state = _checkpoint_state(args, epoch, model, model_ema, contrast,
                                  optimizer)
        torch.save(state, os.path.join(args.model_folder, 'current.pth'))
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
        # help release GPU memory
        del state
        torch.cuda.empty_cache()
Example #2
0
# Pick the learning-rate schedule named by P.lr_scheduler.
if P.lr_scheduler == 'step_decay':
    # Decay at 50% and 75% of the configured number of epochs.
    milestones = [int(0.5 * P.epochs), int(0.75 * P.epochs)]
    scheduler = lr_scheduler.MultiStepLR(
        optimizer,
        milestones=milestones,
        gamma=lr_decay_gamma,
    )
elif P.lr_scheduler == 'cosine':
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, P.epochs)
else:
    raise NotImplementedError()

# Wrap the chosen schedule in a warm-up phase of P.warmup epochs.
from training.scheduler import GradualWarmupScheduler
scheduler_warmup = GradualWarmupScheduler(
    optimizer,
    multiplier=10.0,
    total_epoch=P.warmup,
    after_scheduler=scheduler,
)

# Either restore model/optimizer/bookkeeping from the last checkpoint or
# start from the defaults.
resume = P.resume_path is not None
error = 100.0
if resume:
    model_state, optim_state, config = load_checkpoint(P.resume_path, mode='last')
    model.load_state_dict(model_state, strict=not P.no_strict)
    optimizer.load_state_dict(optim_state)
    start_epoch, best = config['epoch'], config['best']
else:
    start_epoch, best = 1, 100.0

# The linear-evaluation modes require weights from P.load_path.
if P.mode in ('sup_linear', 'sup_CSI_linear'):
    assert P.load_path is not None
    checkpoint = torch.load(P.load_path)
    model.load_state_dict(checkpoint, strict=not P.no_strict)

if P.multi_gpu:
Example #3
0
# Adam wrapped in LARS, with an exponentially decaying learning rate.
# NOTE(review): `args.decay_lr` is used both as Adam's weight decay and as
# the scheduler's gamma — confirm that sharing one value is intended.
base_optimizer = optim.Adam(model.parameters(),
                            lr=args.lr,
                            weight_decay=args.decay_lr)
optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)
scheduler = ExponentialLR(optimizer, gamma=args.decay_lr)

# Main training loop
best_loss = np.inf
# First epoch to run; replaced by the checkpoint's epoch when resuming.
start_epoch = 0

# Resume training
if args.load_model is not None:
    if os.path.isfile(args.load_model):
        checkpoint = torch.load(args.load_model)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        base_optimizer.load_state_dict(checkpoint['base_optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        best_loss = checkpoint['val_loss']
        # Keep the restored epoch so the loop below continues from it instead
        # of silently restarting at 0.
        # NOTE(review): assumes the stored value is the epoch to resume at,
        # as the message below states — confirm against the saving code.
        epoch = start_epoch = checkpoint['epoch']
        print('Loading model: {}. Resuming from epoch: {}'.format(
            args.load_model, epoch))
    else:
        print('Model: {} not found'.format(args.load_model))

for epoch in range(start_epoch, args.epochs):
    v_loss = execute_graph(model, loader, optimizer, scheduler, epoch,
                           use_cuda)

    # Track the lowest loss returned so far.
    if v_loss < best_loss:
        best_loss = v_loss
Example #4
0
def _make_train_state(it, model, optimizer, scheduler, best_macro_OA_now,
                      best_macro_OA_iter_now):
    """Collect everything that is written to a checkpoint file into one dict."""
    return {
        "epoch": it,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "best_macro_OA_now": best_macro_OA_now,
        'best_macro_OA_iter_now': best_macro_OA_iter_now,
    }


def train(cfg, writer, logger):
    """Train a two-input, two-class model and periodically validate it.

    Builds the train/val datasets and loaders from ``cfg``, creates the model
    (wrapped in ``DataParallel``), optimizer (optionally wrapped in LARS),
    scheduler and loss, optionally resumes from ``cfg.train.resume``, then
    iterates until ``train_iter`` optimizer steps have been taken. Training
    metrics are logged every ``cfg.train.print_interval`` iterations and
    validation runs every ``cfg.train.val_interval`` iterations and on the
    final iteration. The best model (by the mean of the two per-class
    accuracies) and the last model are saved into the writer's log directory.

    Args:
        cfg: configuration object; the ``data``, ``train``, ``model`` and
            ``gpu`` attributes are read.
        writer: tensorboard-style writer (``add_scalar``, ``add_scalars``,
            ``get_logdir``, ``file_writer.get_logdir``).
        logger: logger used for progress messages.

    Raises:
        ValueError: if the training loader yields no batches, which would
            otherwise make the training loop spin forever.
    """
    # Setup Augmentations
    augmentations = cfg.train.augment
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path

    logger.info("Using dataset: {}".format(data_path))

    t_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        split=cfg.data.train_split,
        norm=cfg.data.norm,
        augments=data_aug,
    )

    v_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        split=cfg.data.val_split,
    )
    train_data_len = len(t_loader)
    logger.info(f'num of train samples: {train_data_len} \nnum of val samples: {len(v_loader)}')

    batch_size = cfg.train.batch_size
    epoch = cfg.train.epoch
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')

    trainloader = data.DataLoader(t_loader,
                                  batch_size=batch_size,
                                  num_workers=cfg.train.n_workers,
                                  shuffle=True,
                                  persistent_workers=True,
                                  drop_last=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=10,
                                num_workers=cfg.train.n_workers,)

    # Setup Model
    device = f'cuda:{cfg.gpu[0]}'
    model = get_model(cfg.model, 2).to(device)
    logger.info(f"Using Model: {cfg.model.arch}")
    # Runs on multiple GPUs automatically; convenient.
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k: v for k, v in vars(cfg.train.optimizer).items()
                        if k not in ('name', 'wrap')}
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    # The option is called 'wrap' (see the filter above); checking for 'warp'
    # meant the LARS wrapper was never applied.
    if getattr(cfg.train.optimizer, 'wrap', None) == 'lars':
        optimizer = LARS(optimizer=optimizer)
        logger.info(f'warp optimizer with {cfg.train.optimizer.wrap}')
    scheduler = get_scheduler(optimizer, cfg.train.lr)
    loss_fn = get_loss_function(cfg)
    logger.info(f"Using loss ,{str(cfg.train.loss)}")

    # load checkpoints
    best_macro_OA_now = 0
    best_macro_OA_iter_now = 0
    start_iter = 0
    if cfg.train.resume is not None:
        if os.path.isfile(cfg.train.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg.train.resume)
            )

            # load model state
            checkpoint = torch.load(cfg.train.resume)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # Restore the best-so-far bookkeeping so that a worse model
            # cannot overwrite the best checkpoint copied below.
            best_macro_OA_now = checkpoint.get("best_macro_OA_now",
                                               best_macro_OA_now)
            best_macro_OA_iter_now = checkpoint.get("best_macro_OA_iter_now",
                                                    best_macro_OA_iter_now)

            start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg.train.resume, checkpoint["epoch"]
                )
            )

            # copy tensorboard files (everything except logs, configs and
            # the last-model checkpoint) into the new log directory
            resume_src_dir = osp.split(cfg.train.resume)[0]
            resume_dst_dir = writer.get_logdir()
            for file in os.listdir(resume_src_dir):
                if not ('.log' in file or '.yml' in file or '_last_model' in file):
                    fu.copy(osp.join(resume_src_dir, file), resume_dst_dir, )

        else:
            logger.info("No checkpoint found at '{}'".format(cfg.train.resume))

    # Setup Metrics
    running_metrics_val = runningScore(2)
    runing_metrics_train = runningScore(2)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()

    # train
    it = start_iter
    if it < train_iter and len(trainloader) == 0:
        # With drop_last=True fewer samples than batch_size means no batches;
        # the while-loop below would then never advance `it`.
        raise ValueError(
            'training loader yields no batches '
            '(fewer samples than batch_size with drop_last=True)')

    train_start_time = time.time()
    train_val_start_time = time.time()
    model.train()
    while it < train_iter:
        for (file_a, file_b, label, mask) in trainloader:
            it += 1
            file_a = file_a.to(device)
            file_b = file_b.to(device)
            label = label.to(device)
            mask = mask.to(device)

            optimizer.zero_grad()
            outputs = model(file_a, file_b)
            loss = loss_fn(input=outputs, target=label, mask=mask)
            loss.backward()

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()

            # record the acc of the minibatch
            pred = outputs.max(1)[1].cpu().numpy()
            runing_metrics_train.update(label.cpu().numpy(), pred, mask.cpu().numpy())

            train_time_meter.update(time.time() - train_start_time)

            if it % cfg.train.print_interval == 0:
                # acc of the samples between print_interval
                score, _ = runing_metrics_train.get_scores()
                train_cls_0_acc, train_cls_1_acc = score['Acc']
                fmt_str = "Iter [{:d}/{:d}]  train Loss: {:.4f}  Time/Image: {:.4f},\n0:{:.4f}\n1:{:.4f}"
                print_str = fmt_str.format(it,
                                           train_iter,
                                           loss.item(),  # loss value as a Python float
                                           train_time_meter.avg / cfg.train.batch_size,
                                           train_cls_0_acc,
                                           train_cls_1_acc)
                runing_metrics_train.reset()
                train_time_meter.reset()
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), it)
                writer.add_scalars('metrics/train', {'cls_0': train_cls_0_acc, 'cls_1': train_cls_1_acc}, it)

            if it % cfg.train.val_interval == 0 or it == train_iter:
                val_start_time = time.time()
                model.eval()            # change behavior like drop out
                with torch.no_grad():   # disable autograd, save memory usage
                    for (file_a_val, file_b_val, label_val, mask_val) in valloader:
                        file_a_val = file_a_val.to(device)
                        file_b_val = file_b_val.to(device)

                        outputs = model(file_a_val, file_b_val)
                        # tensor.max() returns the maximum value and its indices
                        pred = outputs.max(1)[1].cpu().numpy()
                        running_metrics_val.update(label_val.numpy(), pred, mask_val.numpy())

                        label_val = label_val.to(device)
                        mask_val = mask_val.to(device)
                        val_loss = loss_fn(input=outputs, target=label_val, mask=mask_val)
                        val_loss_meter.update(val_loss.item())

                score, _ = running_metrics_val.get_scores()
                val_cls_0_acc, val_cls_1_acc = score['Acc']

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it)
                logger.info(f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} Time/Image: {(time.time()-val_start_time)/len(v_loader):.4f}\n0: {val_cls_0_acc:.4f}\n1:{val_cls_1_acc:.4f}")

                logger.info('0: {:.4f}\n1:{:.4f}'.format(val_cls_0_acc, val_cls_1_acc))
                writer.add_scalars('metrics/val', {'cls_0': val_cls_0_acc, 'cls_1': val_cls_1_acc}, it)

                val_loss_meter.reset()
                running_metrics_val.reset()

                # Model selection uses the mean of the two per-class
                # accuracies; the first 200 iterations are never saved.
                val_macro_OA = (val_cls_0_acc + val_cls_1_acc) / 2
                if val_macro_OA >= best_macro_OA_now and it > 200:
                    best_macro_OA_now = val_macro_OA
                    best_macro_OA_iter_now = it
                    state = _make_train_state(it, model, optimizer, scheduler,
                                              best_macro_OA_now,
                                              best_macro_OA_iter_now)
                    save_path = os.path.join(writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format(cfg.model.arch, cfg.data.dataloader))
                    torch.save(state, save_path)

                    logger.info("best OA now =  %.8f" % (best_macro_OA_now))
                    logger.info("best OA iter now= %d" % (best_macro_OA_iter_now))

                # Estimate the remaining time from the iterations done in
                # THIS run, so resumed runs are not counted as instant.
                train_val_time = time.time() - train_val_start_time
                remain_time = train_val_time * (train_iter - it) / (it - start_iter)
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if it < train_iter:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                logger.info(train_time)
                model.train()

            train_start_time = time.time()

            # Stop exactly at the planned number of steps instead of running
            # on to the end of the current pass over the loader.
            if it >= train_iter:
                break

    logger.info("best OA now =  %.8f" % (best_macro_OA_now))
    logger.info("best OA iter now= %d" % (best_macro_OA_iter_now))

    state = _make_train_state(it, model, optimizer, scheduler,
                              best_macro_OA_now, best_macro_OA_iter_now)
    save_path = os.path.join(writer.file_writer.get_logdir(), "{}_{}_last_model.pkl".format(cfg.model.arch, cfg.data.dataloader))
    torch.save(state, save_path)
Example #5
0
def train_and_eval(tag,
                   dataroot,
                   test_ratio=0.0,
                   cv_fold=0,
                   reporter=None,
                   metric='last',
                   save_path=None,
                   only_eval=False):
    if not reporter:
        reporter = lambda **kwargs: 0

    max_epoch = C.get()['epoch']
    trainsampler, trainloader, validloader, testloader_ = get_dataloaders(
        C.get()['dataset'],
        C.get()['batch'],
        dataroot,
        test_ratio,
        split_idx=cv_fold)

    # create a model & an optimizer
    model = get_model(C.get()['model'], num_class(C.get()['dataset']))

    lb_smooth = C.get()['optimizer'].get('label_smoothing', 0.0)
    if lb_smooth > 0.0:
        criterion = SmoothCrossEntropyLoss(lb_smooth)
    else:
        criterion = nn.CrossEntropyLoss()
    if C.get()['optimizer']['type'] == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=C.get()['lr'],
                              momentum=C.get()['optimizer'].get(
                                  'momentum', 0.9),
                              weight_decay=C.get()['optimizer']['decay'],
                              nesterov=C.get()['optimizer']['nesterov'])
    else:
        raise ValueError('invalid optimizer type=%s' %
                         C.get()['optimizer']['type'])

    if C.get()['optimizer'].get('lars', False):
        from torchlars import LARS
        optimizer = LARS(optimizer)
        logger.info('*** LARS Enabled.')

    lr_scheduler_type = C.get()['lr_schedule'].get('type', 'cosine')
    if lr_scheduler_type == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=C.get()['epoch'], eta_min=0.)
    elif lr_scheduler_type == 'resnet':
        scheduler = adjust_learning_rate_resnet(optimizer)
    else:
        raise ValueError('invalid lr_schduler=%s' % lr_scheduler_type)

    if C.get()['lr_schedule'].get('warmup', None):
        scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=C.get()['lr_schedule']['warmup']['multiplier'],
            total_epoch=C.get()['lr_schedule']['warmup']['epoch'],
            after_scheduler=scheduler)

    if not tag:
        from RandAugment.metrics import SummaryWriterDummy as SummaryWriter
        logger.warning('tag not provided, no tensorboard log.')
    else:
        from tensorboardX import SummaryWriter
    writers = [
        SummaryWriter(log_dir='./logs/%s/%s' % (tag, x))
        for x in ['train', 'valid', 'test']
    ]

    result = OrderedDict()
    epoch_start = 1
    if save_path and os.path.exists(save_path):
        logger.info('%s file found. loading...' % save_path)
        data = torch.load(save_path)
        if 'model' in data or 'state_dict' in data:
            key = 'model' if 'model' in data else 'state_dict'
            logger.info('checkpoint epoch@%d' % data['epoch'])
            if not isinstance(model, DataParallel):
                model.load_state_dict({
                    k.replace('module.', ''): v
                    for k, v in data[key].items()
                })
            else:
                model.load_state_dict({
                    k if 'module.' in k else 'module.' + k: v
                    for k, v in data[key].items()
                })
            optimizer.load_state_dict(data['optimizer'])
            if data['epoch'] < C.get()['epoch']:
                epoch_start = data['epoch']
            else:
                only_eval = True
        else:
            model.load_state_dict({k: v for k, v in data.items()})
        del data
    else:
        logger.info('"%s" file not found. skip to pretrain weights...' %
                    save_path)
        if only_eval:
            logger.warning(
                'model checkpoint not found. only-evaluation mode is off.')
        only_eval = False

    if only_eval:
        logger.info('evaluation only+')
        model.eval()
        rs = dict()
        rs['train'] = run_epoch(model,
                                trainloader,
                                criterion,
                                None,
                                desc_default='train',
                                epoch=0,
                                writer=writers[0])
        rs['valid'] = run_epoch(model,
                                validloader,
                                criterion,
                                None,
                                desc_default='valid',
                                epoch=0,
                                writer=writers[1])
        rs['test'] = run_epoch(model,
                               testloader_,
                               criterion,
                               None,
                               desc_default='*test',
                               epoch=0,
                               writer=writers[2])
        for key, setname in itertools.product(['loss', 'top1', 'top5'],
                                              ['train', 'valid', 'test']):
            if setname not in rs:
                continue
            result['%s_%s' % (key, setname)] = rs[setname][key]
        result['epoch'] = 0
        return result

    # train loop
    best_top1 = 0
    for epoch in range(epoch_start, max_epoch + 1):
        model.train()
        rs = dict()
        rs['train'] = run_epoch(model,
                                trainloader,
                                criterion,
                                optimizer,
                                desc_default='train',
                                epoch=epoch,
                                writer=writers[0],
                                verbose=True,
                                scheduler=scheduler)
        model.eval()

        if math.isnan(rs['train']['loss']):
            raise Exception('train loss is NaN.')

        if epoch % 5 == 0 or epoch == max_epoch:
            rs['valid'] = run_epoch(model,
                                    validloader,
                                    criterion,
                                    None,
                                    desc_default='valid',
                                    epoch=epoch,
                                    writer=writers[1],
                                    verbose=True)
            rs['test'] = run_epoch(model,
                                   testloader_,
                                   criterion,
                                   None,
                                   desc_default='*test',
                                   epoch=epoch,
                                   writer=writers[2],
                                   verbose=True)

            if metric == 'last' or rs[metric]['top1'] > best_top1:
                if metric != 'last':
                    best_top1 = rs[metric]['top1']
                for key, setname in itertools.product(
                    ['loss', 'top1', 'top5'], ['train', 'valid', 'test']):
                    result['%s_%s' % (key, setname)] = rs[setname][key]
                result['epoch'] = epoch

                writers[1].add_scalar('valid_top1/best', rs['valid']['top1'],
                                      epoch)
                writers[2].add_scalar('test_top1/best', rs['test']['top1'],
                                      epoch)

                reporter(loss_valid=rs['valid']['loss'],
                         top1_valid=rs['valid']['top1'],
                         loss_test=rs['test']['loss'],
                         top1_test=rs['test']['top1'])

                # save checkpoint
                if save_path:
                    logger.info('save model@%d to %s' % (epoch, save_path))
                    torch.save(
                        {
                            'epoch': epoch,
                            'log': {
                                'train': rs['train'].get_dict(),
                                'valid': rs['valid'].get_dict(),
                                'test': rs['test'].get_dict(),
                            },
                            'optimizer': optimizer.state_dict(),
                            'model': model.state_dict()
                        }, save_path)
                    torch.save(
                        {
                            'epoch': epoch,
                            'log': {
                                'train': rs['train'].get_dict(),
                                'valid': rs['valid'].get_dict(),
                                'test': rs['test'].get_dict(),
                            },
                            'optimizer': optimizer.state_dict(),
                            'model': model.state_dict()
                        },
                        save_path.replace(
                            '.pth', '_e%d_top1_%.3f_%.3f' %
                            (epoch, rs['train']['top1'], rs['test']['top1']) +
                            '.pth'))

    del model

    result['top1_test'] = best_top1
    return result
Example #6
0
def train(cfg, writer, logger):
    """Train a denoising model on noisy sub-image pairs, validating periodically.

    Builds the train/val data loaders, the model, the optimizer and the LR
    scheduler from ``cfg``, optionally resumes from ``cfg.train.resume``, and
    then runs an iteration-based training loop. For every batch, two
    sub-images are sampled from the noisy input via ``rand_pool``; the loss is
    the MSE between ``model(sub1)`` and ``sub2`` plus a ``gamma``-weighted
    term that compares that difference with the same difference taken on the
    (gradient-free) denoised full image.

    Args:
        cfg: Configuration namespace. Reads ``cfg.train`` (augment,
            batch_size, epoch, n_workers, gpu, optimizer, lr, resume,
            loss.gamma, print_interval, val_interval), ``cfg.data`` (dataloader,
            path, format, norm, split, log, ENL, simulate), ``cfg.model``,
            ``cfg.test.batch_size`` and ``cfg.gpu``.
        writer: Summary writer; ``add_scalar``, ``add_scalars``,
            ``add_histogram``, ``get_logdir`` and ``file_writer.get_logdir``
            are called on it.
        logger: Logger used for progress messages (``logger.info``).
    """
    # Setup Augmentations
    augmentations = cfg.train.augment
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path
    logger.info("data path: {}".format(data_path))

    t_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        norm=cfg.data.norm,
        split='train',
        split_root=cfg.data.split,
        augments=data_aug,
        logger=logger,
        log=cfg.data.log,
        ENL=cfg.data.ENL,
    )

    v_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        split='val',
        log=cfg.data.log,
        split_root=cfg.data.split,
        logger=logger,
        ENL=cfg.data.ENL,
    )

    train_data_len = len(t_loader)
    logger.info(
        f'num of train samples: {train_data_len} \nnum of val samples: {len(v_loader)}'
    )

    batch_size = cfg.train.batch_size
    epoch = cfg.train.epoch
    # Total iteration budget: ceil(samples / batch) iterations per epoch.
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')

    trainloader = data.DataLoader(t_loader,
                                  batch_size=batch_size,
                                  num_workers=cfg.train.n_workers,
                                  shuffle=True,
                                  persistent_workers=True,
                                  drop_last=True)

    valloader = data.DataLoader(
        v_loader,
        batch_size=cfg.test.batch_size,
        num_workers=cfg.train.n_workers,
    )

    # Setup Model
    device = f'cuda:{cfg.train.gpu[0]}'
    model = get_model(cfg.model).to(device)
    logger.info(f"Using Model: {cfg.model.arch}")
    # Run on multiple GPUs automatically.
    # NOTE(review): the device comes from cfg.train.gpu but device_ids from
    # cfg.gpu -- confirm both settings always refer to the same GPU list.
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in vars(cfg.train.optimizer).items()
        if k not in ('name', 'wrap')
    }
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    if hasattr(cfg.train.optimizer,
               'wrap') and cfg.train.optimizer.wrap == 'lars':
        optimizer = LARS(optimizer=optimizer)
        logger.info(f'warp optimizer with {cfg.train.optimizer.wrap}')
    scheduler = get_scheduler(optimizer, cfg.train.lr)

    # load checkpoints
    start_iter = 0
    if cfg.train.resume is not None:
        if os.path.isfile(cfg.train.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg.train.resume))

            # Restore onto the training device rather than whichever device
            # the checkpoint happened to be saved from.
            checkpoint = torch.load(cfg.train.resume, map_location=device)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])

            # The "epoch" entry of checkpoints written below holds the
            # iteration counter, not an epoch index.
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg.train.resume, checkpoint["epoch"]))

            # Copy the previous run's artifacts (everything except logs,
            # configs and "_last_model" files) into the new log directory.
            resume_src_dir = osp.split(cfg.train.resume)[0]
            resume_dst_dir = writer.get_logdir()
            for file in os.listdir(resume_src_dir):
                if not ('.log' in file or '.yml' in file
                        or '_last_model' in file):
                    fu.copy(
                        osp.join(resume_src_dir, file),
                        resume_dst_dir,
                    )

        else:
            logger.info("No checkpoint found at '{}'".format(cfg.train.resume))

    # Dynamic range handed to PSNR/SSIM; log-scaled when the data is.
    data_range = 255
    if cfg.data.log:
        data_range = np.log(data_range)

    # Setup Metrics
    # NOTE: the two runningScore objects are only ever reset in this function.
    running_metrics_val = runningScore(2)
    runing_metrics_train = runningScore(2)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()
    train_loss_meter = averageMeter()
    val_psnr_meter = averageMeter()
    val_ssim_meter = averageMeter()

    # train
    it = start_iter
    train_start_time = time.time()
    train_val_start_time = time.time()
    model.train()
    while it < train_iter:
        for clean, noisy, _ in trainloader:
            it += 1

            noisy = noisy.to(device, dtype=torch.float32)
            mask1, mask2 = rand_pool.generate_mask_pair(noisy)
            noisy_sub1 = rand_pool.generate_subimages(noisy, mask1)
            noisy_sub2 = rand_pool.generate_subimages(noisy, mask2)

            # preparing for the regularization term (no gradient flows
            # through the full-image prediction)
            with torch.no_grad():
                noisy_denoised = model(noisy)
            noisy_sub1_denoised = rand_pool.generate_subimages(
                noisy_denoised, mask1)
            noisy_sub2_denoised = rand_pool.generate_subimages(
                noisy_denoised, mask2)

            # Clear gradients left over from the previous iteration;
            # without this they accumulate across steps.
            optimizer.zero_grad()

            # calculating the loss
            noisy_output = model(noisy_sub1)
            noisy_target = noisy_sub2
            if cfg.train.loss.gamma.const:
                gamma = cfg.train.loss.gamma.base
            else:
                # Linearly ramp the regularization weight up to its base.
                gamma = it / train_iter * cfg.train.loss.gamma.base

            diff = noisy_output - noisy_target
            exp_diff = noisy_sub1_denoised - noisy_sub2_denoised
            loss1 = torch.mean(diff**2)
            loss2 = gamma * torch.mean((diff - exp_diff)**2)
            loss_all = loss1 + loss2
            loss_all.backward()

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()

            # record the loss of the minibatch as a plain float so the meter
            # does not hold on to autograd-tracked tensors
            train_loss_meter.update(loss_all.item())
            train_time_meter.update(time.time() - train_start_time)
            writer.add_scalar('lr', optimizer.param_groups[0]['lr'], it)

            if it % 1000 == 0:
                writer.add_histogram('hist/pred', noisy_denoised, it)
                writer.add_histogram('hist/noisy', noisy, it)

                if cfg.data.simulate:
                    writer.add_histogram('hist/clean', clean, it)

            # print interval
            if it % cfg.train.print_interval == 0:
                # Adjacent literals instead of backslash continuation, which
                # leaked the source indentation into the log message.
                terminal_info = (
                    f"Iter [{it:d}/{train_iter:d}]  "
                    f"train Loss: {train_loss_meter.avg:.4f}  "
                    f"Time/Image: {train_time_meter.avg / cfg.train.batch_size:.4f}"
                )

                logger.info(terminal_info)
                writer.add_scalar('loss/train_loss', train_loss_meter.avg, it)

                runing_metrics_train.reset()
                train_time_meter.reset()
                train_loss_meter.reset()

            # val interval
            if it % cfg.train.val_interval == 0 or \
               it == train_iter:
                val_start_time = time.time()
                model.eval()
                with torch.no_grad():
                    for val_clean, val_noisy, _ in valloader:
                        val_noisy = val_noisy.to(device, dtype=torch.float32)
                        val_denoised = model(val_noisy)

                        if cfg.data.simulate:
                            val_clean = val_clean.to(device,
                                                     dtype=torch.float32)
                            psnr = piq.psnr(val_clean,
                                            val_denoised,
                                            data_range=data_range)
                            ssim = piq.ssim(val_clean,
                                            val_denoised,
                                            data_range=data_range)
                            val_psnr_meter.update(psnr.item())
                            val_ssim_meter.update(ssim.item())

                        val_loss = torch.mean((val_denoised - val_noisy)**2)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it)
                logger.info(
                    f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} Time/Image: {(time.time()-val_start_time)/len(v_loader):.4f}"
                )
                val_loss_meter.reset()
                running_metrics_val.reset()

                if cfg.data.simulate:
                    writer.add_scalars('metrics/val', {
                        'psnr': val_psnr_meter.avg,
                        'ssim': val_ssim_meter.avg
                    }, it)
                    logger.info(
                        f'psnr: {val_psnr_meter.avg},\tssim: {val_ssim_meter.avg}'
                    )
                    val_psnr_meter.reset()
                    val_ssim_meter.reset()

                # Extrapolate from the iterations done in *this* run; after a
                # resume `it` also counts iterations from earlier runs.
                train_val_time = time.time() - train_val_start_time
                remain_time = train_val_time * (train_iter - it) / (
                    it - start_iter)
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if it < train_iter:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (
                        h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                logger.info(train_time)
                model.train()

            # save model every 10 (nominal) epochs
            if it % (train_iter / cfg.train.epoch * 10) == 0:
                ep = int(it / (train_iter / cfg.train.epoch))
                state = {
                    "epoch": it,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                }
                save_path = osp.join(writer.file_writer.get_logdir(),
                                     f"{ep}.pkl")
                torch.save(state, save_path)
                logger.info(f'saved model state dict at {save_path}')

            train_start_time = time.time()

            # Stop exactly at the iteration budget. A pass over the loader
            # can otherwise run past it (drop_last=True or a resumed `it`
            # means the budget need not fall on a loader boundary).
            if it >= train_iter:
                break
def main_worker(gpu, ngpus_per_node, args):
    """Per-process worker: build PixPro, set up (distributed) training, run it.

    Args:
        gpu: GPU index for this process (stored on ``args.gpu``). May be
            ``None`` only when ``args.distributed`` is set.
        ngpus_per_node: Number of GPUs per node; used to derive the global
            rank and to split ``args.batch_size`` / ``args.workers`` across
            processes.
        args: Parsed options namespace. Mutated in place: ``gpu``, ``rank``,
            ``lr``, ``batch_size``, ``workers`` and ``start_epoch`` may be
            overwritten.

    Raises:
        NotImplementedError: If training is neither distributed nor given a
            GPU index.
    """
    args.gpu = gpu

    if args.distributed:
        if args.dist_url == 'env://' and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # Global rank = node rank * GPUs per node + local GPU index.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    model = PixPro(
        encoder=resnet50,
        dim1=args.pcl_dim_1,
        dim2=args.pcl_dim_2,
        momentum=args.encoder_momentum,
        threshold=args.threshold,
        temperature=args.T,
        sharpness=args.sharpness,
        num_linear=args.num_linear,
    )

    # Scale the base learning rate linearly with the (global) batch size.
    args.lr = args.lr_base * args.batch_size / 256

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)

            # One process per GPU: split the batch size and the data-loading
            # workers across the processes of this node.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)

            # convert batch norm --> sync batch norm; wrap the module that
            # the conversion returns instead of discarding it.
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])

        else:
            model.cuda()
            # Previously referenced the misspelled name `moel` (NameError).
            model = torch.nn.parallel.DistributedDataParallel(model)

    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        raise NotImplementedError('only DDP is supported.')

    base_optimizer = torch.optim.SGD(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    optimizer = LARS(optimizer=base_optimizer, eps=1e-8)
    writer = SummaryWriter(args.log_dir)
    if args.resume:
        # Restore onto this process's GPU; without map_location the tensors
        # are restored onto the device they were saved from.
        map_location = None
        if args.gpu is not None:
            map_location = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.resume, map_location=map_location)
        args.start_epoch = checkpoint['epoch']

        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    dataset = PixProDataset(root=args.train_path, args=args)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset)
    else:
        train_sampler = None

    # Shuffling is delegated to the sampler when one is used.
    loader = DataLoader(dataset,
                        batch_size=args.batch_size,
                        shuffle=(train_sampler is None),
                        num_workers=args.workers,
                        pin_memory=True,
                        sampler=train_sampler,
                        drop_last=True)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Pass the epoch to the sampler before each pass over the data.
            train_sampler.set_epoch(epoch)

        adjust_lr(optimizer, epoch, args)
        train(args, epoch, loader, model, optimizer, writer)

        # With one process per GPU, only the first process on each node
        # writes the checkpoint.
        if (not args.multiprocessing_distributed
                or args.rank % ngpus_per_node == 0):
            save_name = '{}.pth.tar'.format(epoch)
            save_name = os.path.join(args.checkpoint_dir, save_name)
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, save_name)