def __init__(self, device, train_loader, val_loader, model, MODEL_PATH):
        self.device = device
        self.model_path = MODEL_PATH
        self.train_loader = train_loader
        self.val_loader = val_loader

        self.model = model.to(device)
        base_optimizer = optim.Adam(self.model.parameters(), lr=2e-4)
        self.optimizer = LARS(optimizer=base_optimizer,
                              eps=1e-8,
                              trust_coef=0.001)
        self.criterion = nn.BCEWithLogitsLoss()
        self.scaler = GradScaler()
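
A minimal training-step sketch, assuming the attributes created in the __init__ above; the method name and the inputs/targets arguments are illustrative, not part of the original snippet:

def train_step(self, inputs, targets):
        # Hypothetical helper; assumes `inputs` and `targets` already live on self.device.
        self.optimizer.zero_grad()
        with torch.cuda.amp.autocast():         # mixed-precision forward pass
            loss = self.criterion(self.model(inputs), targets)
        self.scaler.scale(loss).backward()      # backward on the scaled loss
        self.scaler.step(self.optimizer)        # unscale grads, then step the LARS-wrapped Adam
        self.scaler.update()                    # adjust the loss scale for the next iteration
        return loss.item()
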
Example #2
def test_pickle():
    sgd = SGD([torch.Tensor()], lr=0.1)
    lars = LARS(sgd, trust_coef=0.42)
    lars2 = pickle.loads(pickle.dumps(lars))

    assert isinstance(lars2.optim, SGD)
    assert lars2.optim.param_groups[0]['lr'] == 0.1
    assert lars2.trust_coef == 0.42
Example #3
def test_deepcopy():
    sgd = SGD([torch.Tensor()], lr=0.1)
    lars = LARS(sgd, trust_coef=0.42)
    lars2 = copy.deepcopy(lars)

    assert isinstance(lars2.optim, SGD)
    assert lars2.optim.param_groups[0]['lr'] == 0.1
    assert lars2.trust_coef == 0.42
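
For reference, the basic wrapping pattern these tests exercise, written out as a self-contained sketch (the toy linear model and random data are illustrative assumptions, not taken from the tests):

import torch
from torch import nn
from torch.optim import SGD
from torchlars import LARS

model = nn.Linear(10, 1)
base_optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)

x, y = torch.randn(4, 10), torch.randn(4, 1)
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()  # scales each group's lr by the LARS trust ratio, then delegates to the wrapped SGD
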
Example #4
    model = DocumentReader(args.hidden_dim,
                           args.emb_dim,
                           args.layers,
                           NUM_DIRECTIONS,
                           args.dropout,
                           device).to(device)
    model = torch.nn.DataParallel(model)

    writer = SummaryWriter(comment="_nlp_%s_%s_%s" % (args.optimizer, args.batch_size, args.learning_rate))
    weight_decay = args.learning_rate / args.epochs
    if args.optimizer == 'lamb':
        optimizer = Lamb(model.parameters(), lr=args.learning_rate, weight_decay=weight_decay,
                         betas=(.9, .999), adam=False, writer=writer)
    elif args.optimizer == 'lars':
        base_optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=weight_decay)
        optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001, writer=writer)
    elif args.optimizer == 'sgd':
        optimizer = SGD(model.parameters(), lr=args.learning_rate, momentum=0.9,
                        weight_decay=weight_decay, writer=writer)
    else:
        # use adam optimizer
        optimizer = Lamb(model.parameters(), lr=args.learning_rate, weight_decay=weight_decay,
                         betas=(.9, .999), adam=True, writer=writer)
    print(f'The model has {count_parameters(model):,} trainable parameters')
    ckpt_dir_name = "%s_%s_%s" % (args.working_dir, args.optimizer, args.batch_size)
    model, optimizer = load_pretrained_model(model, optimizer,
                                             "%s/ckpt/%s" % (ckpt_dir_name, "best_weights.pt"))

    print(args)
    ckpt_dir = os.path.join(ckpt_dir_name, 'ckpt')
    os.makedirs(ckpt_dir, exist_ok=True)
Example #5
def main():

    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # set the data loader
    data_folder = os.path.join(args.data_folder, 'train')
    val_folder = os.path.join(args.data_folder, 'val')

    crop_padding = 32
    image_size = 224
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    if args.aug == 'NULL' and args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'CJ':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    # elif args.aug == 'NULL' and args.dataset == 'cifar':
    #     train_transform = transforms.Compose([
    #         transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
    #         transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    #         transforms.RandomGrayscale(p=0.2),
    #         transforms.RandomHorizontalFlip(p=0.5),
    #         transforms.ToTensor(),
    #         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    #     ])
    #
    #     test_transform = transforms.Compose([
    #         transforms.ToTensor(),
    #         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    #     ])
    elif args.aug == 'simple' and args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            get_color_distortion(1.0),
            transforms.ToTensor(),
            normalize,
        ])

        # TODO: Currently follow CMC
        test_transform = transforms.Compose([
            transforms.Resize(image_size + crop_padding),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'simple' and args.dataset == 'cifar':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size=32),
            transforms.RandomHorizontalFlip(p=0.5),
            get_color_distortion(0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

    else:
        raise NotImplementedError('augmentation not supported: {}'.format(args.aug))

    # Get Datasets
    if args.dataset == "imagenet":
        train_dataset = ImageFolderInstance(data_folder,
                                            transform=train_transform,
                                            two_crop=args.moco)
        print(len(train_dataset))
        train_sampler = None
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.num_workers,
            pin_memory=True,
            sampler=train_sampler)

        test_dataset = datasets.ImageFolder(val_folder,
                                            transform=test_transform)

        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=256,
                                                  shuffle=False,
                                                  num_workers=args.num_workers,
                                                  pin_memory=True)

    elif args.dataset == 'cifar':
        # cifar-10 dataset
        if args.contrastive_model == 'simclr':
            train_dataset = CIFAR10Instance_double(root='./data',
                                                   train=True,
                                                   download=True,
                                                   transform=train_transform,
                                                   double=True)
        else:
            train_dataset = CIFAR10Instance(root='./data',
                                            train=True,
                                            download=True,
                                            transform=train_transform)
        train_sampler = None
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.num_workers,
            pin_memory=True,
            sampler=train_sampler,
            drop_last=True)

        test_dataset = CIFAR10Instance(root='./data',
                                       train=False,
                                       download=True,
                                       transform=test_transform)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=100,
                                                  shuffle=False,
                                                  num_workers=args.num_workers)

    # create model and optimizer
    n_data = len(train_dataset)

    if args.model == 'resnet50':
        model = InsResNet50()
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50()
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2)
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50(width=2)
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4)
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50(width=4)
    elif args.model == 'resnet50_cifar':
        model = InsResNet50_cifar()
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50_cifar()
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    # copy weights from `model' to `model_ema'
    if args.contrastive_model == 'moco':
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    if args.contrastive_model == 'moco':
        contrast = MemoryMoCo(128, n_data, args.nce_k, args.nce_t,
                              args.softmax).cuda(args.gpu)
    elif args.contrastive_model == 'simclr':
        contrast = None
    else:
        contrast = MemoryInsDis(128, n_data, args.nce_k, args.nce_t,
                                args.nce_m, args.softmax).cuda(args.gpu)

    if args.softmax:
        criterion = NCESoftmaxLoss()
    elif args.contrastive_model == 'simclr':
        criterion = BatchCriterion(1, args.nce_t, args.batch_size)
    else:
        criterion = NCECriterion(n_data)
    criterion = criterion.cuda(args.gpu)

    model = model.cuda()
    if args.contrastive_model == 'moco':
        model_ema = model_ema.cuda()

    # Exclude BN and bias if needed
    weight_decay = args.weight_decay
    if weight_decay and args.filter_weight_decay:
        parameters = add_weight_decay(model, weight_decay,
                                      args.filter_weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    optimizer = torch.optim.SGD(parameters,
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=weight_decay)
    cudnn.benchmark = True

    if args.amp:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level)
        if args.contrastive_model == 'moco':
            optimizer_ema = torch.optim.SGD(model_ema.parameters(),
                                            lr=0,
                                            momentum=0,
                                            weight_decay=0)
            model_ema, optimizer_ema = amp.initialize(model_ema,
                                                      optimizer_ema,
                                                      opt_level=args.opt_level)

    if args.LARS:
        optimizer = LARS(optimizer=optimizer, eps=1e-8, trust_coef=0.001)

    # optionally resume from a checkpoint
    args.start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if contrast:
                contrast.load_state_dict(checkpoint['contrast'])
            if args.contrastive_model == 'moco':
                model_ema.load_state_dict(checkpoint['model_ema'])

            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])

            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        print("==> training...")

        time1 = time.time()
        if args.contrastive_model == 'moco':
            loss, prob = train_moco(epoch, train_loader, model, model_ema,
                                    contrast, criterion, optimizer, args)
        elif args.contrastive_model == 'simclr':
            print("Train using simclr")
            loss, prob = train_simclr(epoch, train_loader, model, criterion,
                                      optimizer, args)
        else:
            print("Train using InsDis")
            loss, prob = train_ins(epoch, train_loader, model, contrast,
                                   criterion, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('ins_loss', loss, epoch)
        logger.log_value('ins_prob', prob, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'],
                         epoch)

        test_epoch = 2
        if epoch % test_epoch == 0:
            model.eval()

            if args.contrastive_model == 'moco':
                model_ema.eval()

            print('----------Evaluation---------')
            start = time.time()

            if args.dataset == 'cifar':
                acc = kNN(epoch,
                          model,
                          train_loader,
                          test_loader,
                          200,
                          args.nce_t,
                          n_data,
                          low_dim=128,
                          memory_bank=None)

            print("Evaluation Time: '{}'s".format(time.time() - start))
            # writer.add_scalar('nn_acc', acc, epoch)
            logger.log_value('Test accuracy', acc, epoch)

            # print('accuracy: {}% \t (best acc: {}%)'.format(acc, best_acc))
            print('[Epoch]: {}'.format(epoch))
            print('accuracy: {}%)'.format(acc))
            # test_log_file.flush()

        # save model
        if epoch % args.save_freq == 0:
            print('==> Saving...')
            state = {
                'opt': args,
                'model': model.state_dict(),
                # 'contrast': contrast.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
            }
            if args.contrastive_model == 'moco':
                state['model_ema'] = model_ema.state_dict()
            if args.amp:
                state['amp'] = amp.state_dict()
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
            # help release GPU memory
            del state

        # saving the model
        print('==> Saving...')
        state = {
            'opt': args,
            'model': model.state_dict(),
            # 'contrast': contrast.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        if args.contrastive_model == 'moco':
            state['model_ema'] = model_ema.state_dict()
        if args.amp:
            state['amp'] = amp.state_dict()
        save_file = os.path.join(args.model_folder, 'current.pth')
        torch.save(state, save_file)
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
        # help release GPU memory
        del state
        torch.cuda.empty_cache()
Example #6
    def run_main(self):
        """Run the full training and evaluation routine."""
        """Fix the random seed"""
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        """RandAugment parameters"""
        if self.flag_randaug == 1:
            if self.rand_n == 0 and self.rand_m == 0:
                if self.n_model == 'ResNet':
                    self.rand_n = 2
                    self.rand_m = 9
                elif self.n_model == 'WideResNet':
                    if self.n_data == 'CIFAR-10':
                        self.rand_n = 3
                        self.rand_m = 5
                    elif self.n_data == 'CIFAR-100':
                        self.rand_n = 2
                        self.rand_m = 14
                    elif self.n_data == 'SVHN':
                        self.rand_n = 3
                        self.rand_m = 7
        """Dataset"""
        traintest_dataset = dataset.MyDataset_training(
            n_data=self.n_data,
            num_data=self.num_training_data,
            seed=self.seed,
            flag_randaug=self.flag_randaug,
            rand_n=self.rand_n,
            rand_m=self.rand_m,
            cutout=self.cutout)
        self.num_channel, self.num_classes, self.size_after_cnn, self.input_size, self.hidden_size = traintest_dataset.get_info(
            n_data=self.n_data)

        n_samples = len(traintest_dataset)
        if self.num_training_data == 0:
            self.num_training_data = n_samples

        if self.flag_traintest == 1:
            # train_size = self.num_classes * 100
            train_size = int(n_samples * 0.65)
            test_size = n_samples - train_size
            train_dataset, test_dataset = torch.utils.data.random_split(
                traintest_dataset, [train_size, test_size])

            train_sampler = None
            test_sampler = None
        else:
            train_dataset = traintest_dataset
            test_dataset = dataset.MyDataset_test(n_data=self.n_data)

            train_sampler = train_dataset.sampler
            test_sampler = test_dataset.sampler

        num_workers = 16
        train_shuffle = True
        test_shuffle = False

        if train_sampler:
            train_shuffle = False
        if test_sampler:
            test_shuffle = False
        self.train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size_training,
            sampler=train_sampler,
            shuffle=train_shuffle,
            num_workers=num_workers,
            pin_memory=True)
        self.test_loader = torch.utils.data.DataLoader(
            dataset=test_dataset,
            batch_size=self.batch_size_test,
            sampler=test_sampler,
            shuffle=test_shuffle,
            num_workers=num_workers,
            pin_memory=True)
        """Transfer learning"""
        if self.flag_transfer == 1:
            pretrained = True
            num_classes = 1000
        else:
            pretrained = False
            num_classes = self.num_classes
        """Neural network model"""
        model = None
        if self.n_model == 'CNN':
            model = cnn.ConvNet(num_classes=self.num_classes,
                                num_channel=self.num_channel,
                                size_after_cnn=self.size_after_cnn,
                                n_aug=self.n_aug)
        elif self.n_model == 'ResNet':
            model = resnet.ResNet(n_data=self.n_data,
                                  depth=50,
                                  num_classes=self.num_classes,
                                  num_channel=self.num_channel,
                                  n_aug=self.n_aug,
                                  bottleneck=True)  # resnet50
            # model = resnet.ResNet(n_data=self.n_data, depth=200, num_classes=self.num_classes, num_channel=self.num_channel, bottleneck=True)  # resnet200
        elif self.n_model == 'WideResNet':
            # model = resnet.wide_resnet50_2(num_classes=self.num_classes, num_channel=self.num_channel)
            # model = WideResNet(depth=40, iden_factor=2, dropout_rate=0.0, num_classes=num_class, num_channel=self.num_channel)  # wresnet40_2
            model = wideresnet.WideResNet(depth=28,
                                          widen_factor=10,
                                          dropout_rate=0.0,
                                          num_classes=self.num_classes,
                                          num_channel=self.num_channel,
                                          n_aug=self.n_aug)  # wresnet28_10
        print(torch.__version__)
        """Transfer learning"""
        if self.flag_transfer == 1:
            for param in model.parameters():
                param.requires_grad = False

            num_features = model.fc.in_features
            model.fc = nn.Linear(num_features, self.num_classes)
        """Show paramters"""
        if self.show_params == 1:
            params = 0
            for p in model.parameters():
                if p.requires_grad:
                    params += p.numel()
            print(params)
        """GPU setting"""
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = model.to(device)
        if device == 'cuda':
            if self.gpu_multi == 1:
                model = torch.nn.DataParallel(model)
            torch.backends.cudnn.benchmark = True
            print('GPU={}'.format(torch.cuda.device_count()))
        """Loss function"""
        if self.lb_smooth > 0.0:
            criterion = objective.SmoothCrossEntropyLoss(self.lb_smooth)
        else:
            criterion = objective.SoftCrossEntropy()
        """Optimizer"""
        optimizer = 0

        if self.opt == 0:  # Adam
            if self.flag_transfer == 1:
                optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)
            else:
                optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        elif self.opt == 1:  # SGD
            if self.n_model == 'ResNet':
                lr = 0.1
                weight_decay = 0.0001
            elif self.n_model == 'WideResNet':
                if self.n_data == 'SVHN':
                    lr = 0.005
                    weight_decay = 0.001
                else:
                    lr = 0.1
                    weight_decay = 0.0005
            else:
                lr = 0.1
                weight_decay = 0.0005

            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=lr,
                                        momentum=0.9,
                                        weight_decay=weight_decay,
                                        nesterov=True)
        if self.flag_lars == 1:
            from torchlars import LARS
            optimizer = LARS(optimizer)
        """Learning rate scheduling"""
        scheduler = None
        if self.flag_lr_schedule == 2:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.num_epochs, eta_min=0.)
        elif self.flag_lr_schedule == 3:
            if self.num_epochs == 90:
                scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer, [30, 60, 80])
            elif self.num_epochs == 180:
                scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer, [60, 120, 160])
            elif self.num_epochs == 270:
                scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer, [90, 180, 240])

        if self.flag_warmup == 1:
            if self.n_model == 'ResNet':
                multiplier = 2
                total_epoch = 3
            elif self.n_model == 'WideResNet':
                multiplier = 2
                if self.n_data == 'SVHN':
                    total_epoch = 3
                else:
                    total_epoch = 5
            else:
                multiplier = 2
                total_epoch = 3

            scheduler = GradualWarmupScheduler(optimizer,
                                               multiplier=multiplier,
                                               total_epoch=total_epoch,
                                               after_scheduler=scheduler)
        """Initialize"""
        self.flag_noise = np.random.randint(0, 2, self.num_training_data)

        if self.flag_acc5 == 1:
            results = np.zeros((self.num_epochs, 6))
        else:
            results = np.zeros((self.num_epochs, 5))
        start_time = timeit.default_timer()

        t = 0
        # fixed_interval = 10
        fixed_interval = 1
        loss_fixed_all = np.zeros(self.num_epochs // fixed_interval)
        self.loss_training_batch = np.zeros(
            int(self.num_epochs *
                np.ceil(self.num_training_data / self.batch_size_training)))

        for epoch in range(self.num_epochs):
            """Training"""
            model.train()
            start_epoch_time = timeit.default_timer()  # Get the start time of this epoch

            loss_each_all = np.zeros(self.num_training_data)
            loss_training_all = 0
            loss_test_all = 0
            """Learning rate scheduling"""
            # if self.flag_lr_schedule == 1:
            #     if self.num_epochs == 200:
            #         if epoch == 100:
            #             optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

            if self.flag_variance == 1:  # when computing the variance of training loss
                if epoch % fixed_interval == 0:
                    loss_fixed = np.zeros(self.num_training_data //
                                          self.batch_size_training)

                    for i, (images, labels,
                            index) in enumerate(self.train_loader):
                        if np.array(images.data.cpu()).ndim == 3:
                            images = images.reshape(images.shape[0], 1,
                                                    images.shape[1],
                                                    images.shape[2]).to(device)
                        else:
                            images = images.to(device)
                        labels = labels.to(device)

                        flag_onehot = 0
                        if self.flag_spa == 1:
                            outputs_fixed = model.forward(images)
                            if flag_onehot == 0:
                                labels = np.identity(self.num_classes)[labels]
                        else:
                            outputs_fixed = model(images)
                            labels = np.identity(self.num_classes)[np.array(
                                labels.data.cpu())]
                        labels = util.to_var(torch.from_numpy(labels).float())

                        loss_fixed[i] = criterion.forward(
                            outputs_fixed, labels)
                    loss_fixed_all[t] = np.var(loss_fixed)
                    t += 1

            total_steps = len(self.train_loader)
            steps = 0
            num_training_data = 0
            for i, (images, labels, index) in enumerate(self.train_loader):
                steps += 1
                if np.array(images.data.cpu()).ndim == 3:
                    images = images.reshape(images.shape[0], 1,
                                            images.shape[1],
                                            images.shape[2]).to(device)
                else:
                    images = images.to(device)
                labels = labels.to(device)
                """Save images"""
                if self.save_images == 1:
                    util.save_images(images)
                """Get training loss for each sample by inputting data before applying data augmentation"""
                if self.flag_spa == 1:
                    outputs = model(images)
                    labels_spa = labels.clone()

                    if labels_spa.ndim == 1:
                        labels_spa = torch.eye(
                            self.num_classes,
                            device='cuda')[labels_spa].clone()  # To one-hot

                    loss_each = criterion.forward_each_example(
                        outputs, labels_spa)  # Loss for each sample
                    loss_each_all[index] = np.array(
                        loss_each.data.cpu())  # Put losses for target samples

                    self.flag_noise = util.flag_update(
                        loss_each_all,
                        self.judge_noise)  # Update the noise flag
                """Forward propagation"""
                if self.flag_spa == 1:
                    images, labels = util.self_paced_augmentation(
                        images=images,
                        labels=labels,
                        flag_noise=self.flag_noise,
                        index=np.array(index.data.cpu()),
                        n_aug=self.n_aug,
                        num_classes=self.num_classes)
                else:
                    images, labels = util.run_n_aug(
                        x=images,
                        y=labels,
                        n_aug=self.n_aug,
                        num_classes=self.num_classes)

                outputs = model(images)

                if labels.ndim == 1:
                    labels = torch.eye(self.num_classes,
                                       device='cuda')[labels].clone()

                loss_training = criterion.forward(outputs, labels)
                loss_training_all += loss_training.item() * outputs.shape[0]
                # self.loss_training_batch[int(i + epoch * np.ceil(self.num_training_data / self.batch_size_training))] = loss_training * outputs.shape[0]
                num_training_data += images.shape[0]
                """Back propagation and update"""
                optimizer.zero_grad()
                loss_training.backward()
                optimizer.step()
                """When changing flag_noise randomly"""
                """
                if self.flag_spa == 1:
                    outputs = model.forward(x=images)

                    if labels.ndim == 1:
                        y_soft = torch.eye(self.num_classes, device='cuda')[labels]  # Convert to one-hot
                    loss_each = criterion.forward_each_example(outputs, y_soft)  # Loss for each sample
                    loss_each_all[index] = np.array(loss_each.data.cpu())  # Put losses for target samples
    
                    self.flag_noise = util.flag_update(loss_each_all, self.judge_noise)  # Update the noise flag
                """

            loss_training_each = loss_training_all / num_training_data
            # np.random.shuffle(self.flag_noise)
            """Test"""
            model.eval()

            with torch.no_grad():
                if self.flag_acc5 == 1:
                    top1 = list()
                    top5 = list()
                else:
                    correct = 0
                    total = 0

                num_test_data = 0
                for images, labels in self.test_loader:
                    if np.array(images.data).ndim == 3:
                        images = images.reshape(images.shape[0], 1,
                                                images.shape[1],
                                                images.shape[2]).to(device)
                    else:
                        images = images.to(device)
                    labels = labels.to(device)

                    outputs = model(x=images)

                    if self.flag_acc5 == 1:
                        acc1, acc5 = util.accuracy(outputs.data,
                                                   labels.long(),
                                                   topk=(1, 5))
                        top1.append(acc1[0].item())
                        top5.append(acc5[0].item())
                    else:
                        _, predicted = torch.max(outputs.data, 1)
                        correct += (predicted == labels.long()).sum().item()
                        total += labels.size(0)

                    if labels.ndim == 1:
                        labels = torch.eye(self.num_classes,
                                           device='cuda')[labels]

                    loss_test = criterion.forward(outputs, labels)
                    loss_test_all += loss_test.item() * outputs.shape[0]
                    num_test_data += images.shape[0]
            """Compute test results"""
            top1_avg = 0
            top5_avg = 0
            test_accuracy = 0

            if self.flag_acc5 == 1:
                top1_avg = sum(top1) / float(len(top1))
                top5_avg = sum(top5) / float(len(top5))
            else:
                test_accuracy = 100.0 * correct / total

            loss_test_each = loss_test_all / num_test_data
            """Compute running time"""
            end_epoch_time = timeit.default_timer()
            epoch_time = end_epoch_time - start_epoch_time
            num_flag = np.sum(self.flag_noise == 1)
            """Learning rate scheduling"""
            if self.flag_lr_schedule > 1 and scheduler is not None:
                scheduler.step(epoch - 1 + float(steps) / total_steps)
            """Show results"""
            flag_log = 1
            if flag_log == 1:
                if self.flag_acc5 == 1:
                    print(
                        'Epoch [{}/{}], Train Loss: {:.4f}, Top1 Test Acc: {:.3f} %, Top5 Test Acc: {:.3f} %, Test Loss: {:.4f}, Epoch Time: {:.2f}s, Num_flag: {}'
                        .format(epoch + 1, self.num_epochs, loss_training_each,
                                top1_avg, top5_avg, loss_test_each, epoch_time,
                                num_flag))
                else:
                    print(
                        'Epoch [{}/{}], Train Loss: {:.4f}, Test Acc: {:.3f} %, Test Loss: {:.4f}, Epoch Time: {:.2f}s, Num_flag: {}'
                        .format(epoch + 1, self.num_epochs, loss_training_each,
                                test_accuracy, loss_test_each, epoch_time,
                                num_flag))

            if self.flag_wandb == 1:
                wandb.log({"loss_training_each": loss_training_each})
                wandb.log({"test_accuracy": test_accuracy})
                wandb.log({"loss_test_each": loss_test_each})
                wandb.log({"num_flag": num_flag})
                wandb.log({"epoch_time": epoch_time})

            if self.save_file == 1:
                if flag_log == 1:
                    if self.flag_acc5 == 1:
                        results[epoch][0] = loss_training_each
                        results[epoch][1] = top1_avg
                        results[epoch][2] = top5_avg
                        results[epoch][3] = loss_test_each
                        results[epoch][4] = num_flag
                        results[epoch][5] = epoch_time
                    else:
                        results[epoch][0] = loss_training_each
                        results[epoch][1] = test_accuracy
                        results[epoch][2] = loss_test_each
                        results[epoch][3] = num_flag
                        results[epoch][4] = epoch_time

        end_time = timeit.default_timer()

        flag_log = 1
        if flag_log == 1:
            print(' ran for %.4fm' % ((end_time - start_time) / 60.))
        """Show accuracy"""
        top1_avg_max = np.max(results[:, 1])
        print(top1_avg_max)

        if flag_log == 1 and self.flag_acc5 == 1:
            top5_avg_max = np.max(results[:, 2])
            print(top5_avg_max)
        """Save files"""
        if self.save_file == 1:
            if self.flag_randaug == 1:
                np.savetxt(
                    'results/data_%s_model_%s_num_%s_randaug_%s_n_%s_m_%s_seed_%s_acc_%s.csv'
                    % (self.n_data, self.n_model, self.num_training_data,
                       self.flag_randaug, self.rand_n, self.rand_m, self.seed,
                       top1_avg_max),
                    results,
                    delimiter=',')
            else:
                if self.flag_spa == 1:
                    np.savetxt(
                        'results/data_%s_model_%s_judge_%s_aug_%s_num_%s_seed_%s_acc_%s.csv'
                        % (self.n_data, self.n_model, self.judge_noise,
                           self.n_aug, self.num_training_data, self.seed,
                           top1_avg_max),
                        results,
                        delimiter=',')
                else:
                    np.savetxt(
                        'results/data_%s_model_%s_aug_%s_num_%s_seed_%s_acc_%s.csv'
                        % (self.n_data, self.n_model, self.n_aug,
                           self.num_training_data, self.seed, top1_avg_max),
                        results,
                        delimiter=',')

            if self.flag_variance == 1:
                np.savetxt('results/loss_variance_judge_%s_aug_%s_acc_%s.csv' %
                           (self.judge_noise, self.n_aug, top1_avg_max),
                           loss_fixed_all,
                           delimiter=',')
Example #7
    def train(self,
              dataloader,
              temperature,
              ckpt_path,
              n_epochs=90,
              save_size=10):

        # trainers
        criterion = NTCrossEntropyLoss(temperature, self.batch_size,
                                       self.device).to(self.device)
        optimizer = LARS(torch.optim.SGD(self.model.parameters(), lr=4))
        # optimizer = optimizer.to(self.device)

        losses = []

        for epoch in range(n_epochs):
            with tqdm(total=len(dataloader)) as progress:
                running_loss = 0
                i = 0
                for (xis, xjs), _ in dataloader:
                    i += 1
                    optimizer.zero_grad()

                    xis = xis.to(self.device)
                    xjs = xjs.to(self.device)

                    # Get representations and projections
                    his, zis = self.model(xis)
                    hjs, zjs = self.model(xjs)

                    # normalize
                    zis = F.normalize(zis, dim=1)
                    zjs = F.normalize(zjs, dim=1)

                    loss = criterion(zis, zjs)
                    running_loss += loss.item()

                    # optimize
                    loss.backward()
                    optimizer.step()

                    # update tqdm
                    progress.set_description('train loss:{:.4f}'.format(
                        loss.item()))
                    progress.update()

                    # record loss
                    if i % save_size == (save_size - 1):
                        losses.append(running_loss / save_size)
                        running_loss = 0

                # save model
                if epoch % 10 == 0:
                    torch.save(
                        {
                            'epoch': epoch,
                            'model_state_dict': self.model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'loss': losses,
                        }, ckpt_path + f'{epoch}')

        return self.return_model(), losses
Example #8
    return v_loss


model = resnet50_cifar(args.feature_size).type(dtype)

if args.multi_gpu:
    model = torch.nn.DataParallel(model, device_ids=[4, 5, 6, 7])
    print('Multi gpu')

# init?

base_optimizer = optim.Adam(model.parameters(),
                            lr=args.lr,
                            weight_decay=args.decay_lr)
optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)
scheduler = ExponentialLR(optimizer, gamma=args.decay_lr)

# Main training loop
best_loss = np.inf

# Resume training
if args.load_model is not None:
    if os.path.isfile(args.load_model):
        checkpoint = torch.load(args.load_model)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        base_optimizer.load_state_dict(checkpoint['base_optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        best_loss = checkpoint['val_loss']
        epoch = checkpoint['epoch']
Example #9
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    
    if args.distributed:
        if args.dist_url == 'env://' and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank)
    
    model = PixPro(
                encoder=resnet50, 
                dim1 = args.pcl_dim_1, 
                dim2 = args.pcl_dim_2, 
                momentum = args.encoder_momentum,
                threshold = args.threshold,
                temperature = args.T,
                sharpness = args.sharpness,
                num_linear = args.num_linear,
                )

    args.lr = args.lr_base * args.batch_size/256
    
    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)

            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node -1) / ngpus_per_node)
            
            # convert batch norm --> sync batch norm
            sync_bn_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(sync_bn_model, device_ids=[args.gpu])

        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)

    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        raise NotImplementedError('only DDP is supported.')

    base_optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    optimizer = LARS(optimizer=base_optimizer, eps=1e-8)
    writer = SummaryWriter(args.log_dir) 
    if args.resume:
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']

        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    dataset = PixProDataset(root=args.train_path, args=args)
    
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        train_sampler = None
    
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
                    num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
    
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        
        adjust_lr(optimizer, epoch, args)
        train(args, epoch, loader, model, optimizer, writer)


        if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
            save_name = '{}.pth.tar'.format(epoch)
            save_name = os.path.join(args.checkpoint_dir, save_name)
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer' : optimizer.state_dict(),
                }, save_name)
Example #10
def prepare_train_eval(rank, world_size, run_name, train_config, model_config,
                       hdf5_path_train):
    cfgs = dict2clsattr(train_config, model_config)
    prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path, mu, sigma, inception_model = None, 0, 0, None, None, None, None, None
    if cfgs.distributed_data_parallel:
        print("Use GPU: {} for training.".format(rank))
        setup(rank, world_size)
        torch.cuda.set_device(rank)

    writer = SummaryWriter(
        log_dir=join('./logs', run_name)) if rank == 0 else None
    if rank == 0:
        logger = make_logger(run_name, None)
        logger.info('Run name : {run_name}'.format(run_name=run_name))
        logger.info(train_config)
        logger.info(model_config)
    else:
        logger = None

    ##### load dataset #####
    if rank == 0: logger.info('Load train datasets...')
    train_dataset = LoadDataset(cfgs.dataset_name,
                                cfgs.data_path,
                                train=True,
                                download=True,
                                resize_size=cfgs.img_size,
                                hdf5_path=hdf5_path_train,
                                random_flip=cfgs.random_flip_preprocessing)
    if cfgs.reduce_train_dataset < 1.0:
        num_train = int(cfgs.reduce_train_dataset * len(train_dataset))
        train_dataset, _ = torch.utils.data.random_split(
            train_dataset,
            [num_train, len(train_dataset) - num_train])
    if rank == 0:
        logger.info('Train dataset size : {dataset_size}'.format(
            dataset_size=len(train_dataset)))

    if rank == 0:
        logger.info('Load {mode} datasets...'.format(mode=cfgs.eval_type))
    eval_mode = True if cfgs.eval_type == 'train' else False
    eval_dataset = LoadDataset(cfgs.dataset_name,
                               cfgs.data_path,
                               train=eval_mode,
                               download=True,
                               resize_size=cfgs.img_size,
                               hdf5_path=None,
                               random_flip=False)
    if rank == 0:
        logger.info('Eval dataset size : {dataset_size}'.format(
            dataset_size=len(eval_dataset)))

    if cfgs.distributed_data_parallel:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        cfgs.batch_size = cfgs.batch_size // world_size
    else:
        train_sampler = None

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=cfgs.batch_size,
                                  shuffle=(train_sampler is None),
                                  pin_memory=True,
                                  num_workers=cfgs.num_workers,
                                  sampler=train_sampler,
                                  drop_last=True)
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=cfgs.batch_size,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=cfgs.num_workers,
                                 drop_last=False)

    ##### build model #####
    if rank == 0: logger.info('Build model...')
    module = __import__(
        'models.{architecture}'.format(architecture=cfgs.architecture),
        fromlist=['something'])
    if rank == 0:
        logger.info('Modules are located on models.{architecture}.'.format(
            architecture=cfgs.architecture))
    Gen = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size,
                           cfgs.g_conv_dim, cfgs.g_spectral_norm,
                           cfgs.attention, cfgs.attention_after_nth_gen_block,
                           cfgs.activation_fn, cfgs.conditional_strategy,
                           cfgs.num_classes, cfgs.g_init, cfgs.G_depth,
                           cfgs.mixed_precision).to(rank)

    Dis = module.Discriminator(
        cfgs.img_size, cfgs.d_conv_dim, cfgs.d_spectral_norm, cfgs.attention,
        cfgs.attention_after_nth_dis_block, cfgs.activation_fn,
        cfgs.conditional_strategy, cfgs.hypersphere_dim, cfgs.num_classes,
        cfgs.nonlinear_embed, cfgs.normalize_embed, cfgs.d_init, cfgs.D_depth,
        cfgs.mixed_precision).to(rank)

    if cfgs.ema:
        if rank == 0:
            logger.info('Prepare EMA for G with decay of {}.'.format(
                cfgs.ema_decay))
        Gen_copy = module.Generator(
            cfgs.z_dim,
            cfgs.shared_dim,
            cfgs.img_size,
            cfgs.g_conv_dim,
            cfgs.g_spectral_norm,
            cfgs.attention,
            cfgs.attention_after_nth_gen_block,
            cfgs.activation_fn,
            cfgs.conditional_strategy,
            cfgs.num_classes,
            initialize=False,
            G_depth=cfgs.G_depth,
            mixed_precision=cfgs.mixed_precision).to(rank)
        Gen_ema = ema(Gen, Gen_copy, cfgs.ema_decay, cfgs.ema_start)
    else:
        Gen_copy, Gen_ema = None, None

    if rank == 0: logger.info(count_parameters(Gen))
    if rank == 0: logger.info(Gen)

    if rank == 0: logger.info(count_parameters(Dis))
    if rank == 0: logger.info(Dis)

    ### define loss functions and optimizers
    G_loss = {
        'vanilla': loss_dcgan_gen,
        'least_square': loss_lsgan_gen,
        'hinge': loss_hinge_gen,
        'wasserstein': loss_wgan_gen
    }
    D_loss = {
        'vanilla': loss_dcgan_dis,
        'least_square': loss_lsgan_dis,
        'hinge': loss_hinge_dis,
        'wasserstein': loss_wgan_dis
    }

    if cfgs.optimizer == "SGD":
        G_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                             Gen.parameters()),
                                      cfgs.g_lr,
                                      momentum=cfgs.momentum,
                                      nesterov=cfgs.nesterov)
        D_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                             Dis.parameters()),
                                      cfgs.d_lr,
                                      momentum=cfgs.momentum,
                                      nesterov=cfgs.nesterov)
    elif cfgs.optimizer == "RMSprop":
        G_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                                 Gen.parameters()),
                                          cfgs.g_lr,
                                          momentum=cfgs.momentum,
                                          alpha=cfgs.alpha)
        D_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                                 Dis.parameters()),
                                          cfgs.d_lr,
                                          momentum=cfgs.momentum,
                                          alpha=cfgs.alpha)
    elif cfgs.optimizer == "Adam":
        G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              Gen.parameters()),
                                       cfgs.g_lr, [cfgs.beta1, cfgs.beta2],
                                       eps=1e-6)
        D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              Dis.parameters()),
                                       cfgs.d_lr, [cfgs.beta1, cfgs.beta2],
                                       eps=1e-6)
    else:
        raise NotImplementedError

    if cfgs.LARS_optimizer:
        G_optimizer = LARS(optimizer=G_optimizer, eps=1e-8, trust_coef=0.001)
        D_optimizer = LARS(optimizer=D_optimizer, eps=1e-8, trust_coef=0.001)

    ##### load checkpoints if needed #####
    if cfgs.checkpoint_folder is None:
        checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name)
    else:
        when = "current" if cfgs.load_current is True else "best"
        if not exists(abspath(cfgs.checkpoint_folder)):
            raise NotADirectoryError
        checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name)
        g_checkpoint_dir = glob.glob(
            join(checkpoint_dir,
                 "model=G-{when}-weights-step*.pth".format(when=when)))[0]
        d_checkpoint_dir = glob.glob(
            join(checkpoint_dir,
                 "model=D-{when}-weights-step*.pth".format(when=when)))[0]
        Gen, G_optimizer, trained_seed, run_name, step, prev_ada_p = load_checkpoint(
            Gen, G_optimizer, g_checkpoint_dir)
        Dis, D_optimizer, trained_seed, run_name, step, prev_ada_p, best_step, best_fid, best_fid_checkpoint_path =\
            load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True)
        if rank == 0: logger = make_logger(run_name, None)
        if cfgs.ema:
            g_ema_checkpoint_dir = glob.glob(
                join(checkpoint_dir,
                     "model=G_ema-{when}-weights-step*.pth".format(
                         when=when)))[0]
            Gen_copy = load_checkpoint(Gen_copy,
                                       None,
                                       g_ema_checkpoint_dir,
                                       ema=True)
            Gen_ema.source, Gen_ema.target = Gen, Gen_copy

        writer = SummaryWriter(
            log_dir=join('./logs', run_name)) if rank == 0 else None
        if cfgs.train_configs['train']:
            assert cfgs.seed == trained_seed, "Seed for sampling random numbers should be same!"

        if rank == 0:
            logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
        if rank == 0:
            logger.info(
                'Discriminator checkpoint is {}'.format(d_checkpoint_dir))
        if cfgs.freeze_layers > -1:
            prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None

    ##### wrap models with DP and convert BN to Sync BN #####
    if world_size > 1:
        if cfgs.distributed_data_parallel:
            if cfgs.synchronized_bn:
                process_group = torch.distributed.new_group(
                    [w for w in range(world_size)])
                Gen = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    Gen, process_group)
                Dis = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    Dis, process_group)
                if cfgs.ema:
                    Gen_copy = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        Gen_copy, process_group)

            Gen = DDP(Gen,
                      device_ids=[rank],
                      broadcast_buffers=False,
                      find_unused_parameters=True)
            Dis = DDP(Dis,
                      device_ids=[rank],
                      broadcast_buffers=False,
                      find_unused_parameters=True)
            if cfgs.ema:
                Gen_copy = DDP(Gen_copy,
                               device_ids=[rank],
                               broadcast_buffers=False,
                               find_unused_parameters=True)
        else:
            Gen = DataParallel(Gen, output_device=rank)
            Dis = DataParallel(Dis, output_device=rank)
            if cfgs.ema:
                Gen_copy = DataParallel(Gen_copy, output_device=rank)

            if cfgs.synchronized_bn:
                Gen = convert_model(Gen).to(rank)
                Dis = convert_model(Dis).to(rank)
                if cfgs.ema:
                    Gen_copy = convert_model(Gen_copy).to(rank)

    ##### load the inception network and prepare first/second moments for calculating FID #####
    if cfgs.eval:
        inception_model = InceptionV3().to(rank)
        if world_size > 1 and cfgs.distributed_data_parallel:
            toggle_grad(inception_model, on=True)
            inception_model = DDP(inception_model,
                                  device_ids=[rank],
                                  broadcast_buffers=False,
                                  find_unused_parameters=True)
        elif world_size > 1 and cfgs.distributed_data_parallel is False:
            inception_model = DataParallel(inception_model, output_device=rank)
        else:
            pass

        mu, sigma = prepare_inception_moments(dataloader=eval_dataloader,
                                              generator=Gen,
                                              eval_mode=cfgs.eval_type,
                                              inception_model=inception_model,
                                              splits=1,
                                              run_name=run_name,
                                              logger=logger,
                                              device=rank)

    worker = make_worker(
        cfgs=cfgs,
        run_name=run_name,
        best_step=best_step,
        logger=logger,
        writer=writer,
        n_gpus=world_size,
        gen_model=Gen,
        dis_model=Dis,
        inception_model=inception_model,
        Gen_copy=Gen_copy,
        Gen_ema=Gen_ema,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        G_optimizer=G_optimizer,
        D_optimizer=D_optimizer,
        G_loss=G_loss[cfgs.adv_loss],
        D_loss=D_loss[cfgs.adv_loss],
        prev_ada_p=prev_ada_p,
        rank=rank,
        checkpoint_dir=checkpoint_dir,
        mu=mu,
        sigma=sigma,
        best_fid=best_fid,
        best_fid_checkpoint_path=best_fid_checkpoint_path,
    )

    if cfgs.train_configs['train']:
        step = worker.train(current_step=step, total_step=cfgs.total_step)

    if cfgs.eval:
        is_save = worker.evaluation(
            step=step,
            standing_statistics=cfgs.standing_statistics,
            standing_step=cfgs.standing_step)

    if cfgs.save_images:
        worker.save_images(is_generate=True,
                           png=True,
                           npz=True,
                           standing_statistics=cfgs.standing_statistics,
                           standing_step=cfgs.standing_step)

    if cfgs.image_visualization:
        worker.run_image_visualization(
            nrow=cfgs.nrow,
            ncol=cfgs.ncol,
            standing_statistics=cfgs.standing_statistics,
            standing_step=cfgs.standing_step)

    if cfgs.k_nearest_neighbor:
        worker.run_nearest_neighbor(
            nrow=cfgs.nrow,
            ncol=cfgs.ncol,
            standing_statistics=cfgs.standing_statistics,
            standing_step=cfgs.standing_step)

    if cfgs.interpolation:
        assert cfgs.architecture in [
            "big_resnet", "biggan_deep"
        ], "StudioGAN does not support interpolation analysis except for biggan and biggan_deep."
        worker.run_linear_interpolation(
            nrow=cfgs.nrow,
            ncol=cfgs.ncol,
            fix_z=True,
            fix_y=False,
            standing_statistics=cfgs.standing_statistics,
            standing_step=cfgs.standing_step)
        worker.run_linear_interpolation(
            nrow=cfgs.nrow,
            ncol=cfgs.ncol,
            fix_z=False,
            fix_y=True,
            standing_statistics=cfgs.standing_statistics,
            standing_step=cfgs.standing_step)

    if cfgs.frequency_analysis:
        worker.run_frequency_analysis(
            num_images=len(train_dataset) // cfgs.num_classes,
            standing_statistics=cfgs.standing_statistics,
            standing_step=cfgs.standing_step)

    if cfgs.tsne_analysis:
        worker.run_tsne(dataloader=eval_dataloader,
                        standing_statistics=cfgs.standing_statistics,
                        standing_step=cfgs.standing_step)
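The block above converts every BatchNorm layer to SyncBatchNorm before wrapping the generator and discriminator in DistributedDataParallel. A minimal, self-contained sketch of that wrapping pattern (the model/rank/world_size arguments are illustrative, not StudioGAN's actual helper):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_ddp(model, rank, world_size, sync_bn=True):
    # Convert BatchNorm layers so statistics are synchronized across processes,
    # then wrap the module with DDP on this process's device.
    if sync_bn and world_size > 1:
        group = dist.new_group(list(range(world_size)))
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, group)
    model = model.to(rank)
    return DDP(model, device_ids=[rank], broadcast_buffers=False,
               find_unused_parameters=True)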
Example #11
0
def train(cfg, writer, logger):
    
    # Set up random seeds to a fixed value for reproducibility
    # seed = 1337
    # torch.manual_seed(seed)
    # torch.cuda.manual_seed(seed)
    # np.random.seed(seed)
    # random.seed(seed)
    # np.random.default_rng(seed)

    # Setup Augmentations
    augmentations = cfg.train.augment
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path

    logger.info("Using dataset: {}".format(data_path))

    t_loader = data_loader(
        data_path,
        # transform=None,
        # time_shuffle = cfg.data.time_shuffle,
        # to_tensor=False,
        data_format = cfg.data.format,
        split=cfg.data.train_split,
        norm = cfg.data.norm,
        augments=data_aug
        )

    v_loader = data_loader(
        data_path,
        # transform=None,
        # time_shuffle = cfg.data.time_shuffle,
        # to_tensor=False,
        data_format = cfg.data.format,
        split=cfg.data.val_split,
        )
    train_data_len = len(t_loader)
    logger.info(f'num of train samples: {train_data_len} \nnum of val samples: {len(v_loader)}')

    batch_size = cfg.train.batch_size
    epoch = cfg.train.epoch
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')

    trainloader = data.DataLoader(t_loader,
                                  batch_size=batch_size, 
                                  num_workers=cfg.train.n_workers, 
                                  shuffle=True,
                                  persistent_workers=True,
                                  drop_last=True)

    valloader = data.DataLoader(v_loader, 
                                batch_size=10, 
                                # persis
                                num_workers=cfg.train.n_workers,)

    # Setup Model
    device = f'cuda:{cfg.gpu[0]}'
    model = get_model(cfg.model, 2).to(device)
    input_size = (cfg.model.input_nbr, 512, 512)
    logger.info(f"Using Model: {cfg.model.arch}")
    # logger.info(f'model summary: {summary(model, input_size=(input_size, input_size), is_complex=True)}')
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)      # runs on multiple GPUs automatically; works well
    
    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k:v for k, v in vars(cfg.train.optimizer).items()
                        if k not in ('name', 'wrap')}
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    if hasattr(cfg.train.optimizer, 'wrap') and cfg.train.optimizer.wrap == 'lars':
        optimizer = LARS(optimizer=optimizer)
        logger.info(f'wrap optimizer with {cfg.train.optimizer.wrap}')
    scheduler = get_scheduler(optimizer, cfg.train.lr)
    loss_fn = get_loss_function(cfg)
    logger.info(f"Using loss ,{str(cfg.train.loss)}")

    # load checkpoints
    val_cls_1_acc = 0
    best_cls_1_acc_now = 0
    best_cls_1_acc_iter_now = 0
    val_macro_OA = 0
    best_macro_OA_now = 0
    best_macro_OA_iter_now = 0
    start_iter = 0
    if cfg.train.resume is not None:
        if os.path.isfile(cfg.train.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg.train.resume)
            )

            # load model state
            checkpoint = torch.load(cfg.train.resume)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # best_cls_1_acc_now = checkpoint["best_cls_1_acc_now"]
            # best_cls_1_acc_iter_now = checkpoint["best_cls_1_acc_iter_now"]

            start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg.train.resume, checkpoint["epoch"]
                )
            )

            # copy tensorboard files
            resume_src_dir = osp.split(cfg.train.resume)[0]
            # shutil.copytree(resume_src_dir, writer.get_logdir())
            for file in os.listdir(resume_src_dir):
                if not ('.log' in file or '.yml' in file or '_last_model' in file):
                # if 'events.out.tfevents' in file:
                    resume_dst_dir = writer.get_logdir()
                    fu.copy(osp.join(resume_src_dir, file), resume_dst_dir, )

        else:
            logger.info("No checkpoint found at '{}'".format(cfg.train.resume))

    # Setup Metrics
    running_metrics_val = runningScore(2)
    runing_metrics_train = runningScore(2)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()

    # train
    it = start_iter
    train_start_time = time.time() 
    train_val_start_time = time.time()
    model.train()   
    while it < train_iter:
        for (file_a, file_b, label, mask) in trainloader:
            it += 1           
            file_a = file_a.to(device)            
            file_b = file_b.to(device)            
            label = label.to(device)            
            mask = mask.to(device)

            optimizer.zero_grad()
            # print(f'dtype: {file_a.dtype}')
            outputs = model(file_a, file_b)
            loss = loss_fn(input=outputs, target=label, mask=mask)
            loss.backward()

            # print('conv11: ', model.conv11.weight.grad, model.conv11.weight.grad.shape)
            # print('conv21: ', model.conv21.weight.grad, model.conv21.weight.grad.shape)
            # print('conv31: ', model.conv31.weight.grad, model.conv31.weight.grad.shape)

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()
            
            # record the acc of the minibatch
            pred = outputs.max(1)[1].cpu().numpy()
            runing_metrics_train.update(label.cpu().numpy(), pred, mask.cpu().numpy())

            train_time_meter.update(time.time() - train_start_time)

            if it % cfg.train.print_interval == 0:
                # acc of the samples between print_interval
                score, _ = runing_metrics_train.get_scores()
                train_cls_0_acc, train_cls_1_acc = score['Acc']
                fmt_str = "Iter [{:d}/{:d}]  train Loss: {:.4f}  Time/Image: {:.4f},\n0:{:.4f}\n1:{:.4f}"
                print_str = fmt_str.format(it,
                                           train_iter,
                                           loss.item(),      #extracts the loss’s value as a Python float.
                                           train_time_meter.avg / cfg.train.batch_size,train_cls_0_acc, train_cls_1_acc)
                runing_metrics_train.reset()
                train_time_meter.reset()
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), it)
                writer.add_scalars('metrics/train', {'cls_0':train_cls_0_acc, 'cls_1':train_cls_1_acc}, it)
                # writer.add_scalar('train_metrics/acc/cls_0', train_cls_0_acc, it)
                # writer.add_scalar('train_metrics/acc/cls_1', train_cls_1_acc, it)

            if it % cfg.train.val_interval == 0 or \
               it == train_iter:
                val_start_time = time.time()
                model.eval()            # change behavior like drop out
                with torch.no_grad():   # disable autograd, save memory usage
                    for (file_a_val, file_b_val, label_val, mask_val) in valloader:      
                        file_a_val = file_a_val.to(device)            
                        file_b_val = file_b_val.to(device)

                        outputs = model(file_a_val, file_b_val)
                        # tensor.max() returns the maximum value and its indices
                        pred = outputs.max(1)[1].cpu().numpy()
                        running_metrics_val.update(label_val.numpy(), pred, mask_val.numpy())
            
                        label_val = label_val.to(device)            
                        mask_val = mask_val.to(device)
                        val_loss = loss_fn(input=outputs, target=label_val, mask=mask_val)
                        val_loss_meter.update(val_loss.item())

                score, _ = running_metrics_val.get_scores()
                val_cls_0_acc, val_cls_1_acc = score['Acc']

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it)
                logger.info(f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} Time/Image: {(time.time()-val_start_time)/len(v_loader):.4f}\n0: {val_cls_0_acc:.4f}\n1:{val_cls_1_acc:.4f}")
                # lr_now = optimizer.param_groups[0]['lr']
                # logger.info(f'lr: {lr_now}')
                # writer.add_scalar('lr', lr_now, it+1)

                logger.info('0: {:.4f}\n1:{:.4f}'.format(val_cls_0_acc, val_cls_1_acc))
                writer.add_scalars('metrics/val', {'cls_0':val_cls_0_acc, 'cls_1':val_cls_1_acc}, it)
                # writer.add_scalar('val_metrics/acc/cls_0', val_cls_0_acc, it)
                # writer.add_scalar('val_metrics/acc/cls_1', val_cls_1_acc, it)

                val_loss_meter.reset()
                running_metrics_val.reset()

                # OA=score["Overall_Acc"]
                val_macro_OA = (val_cls_0_acc+val_cls_1_acc)/2
                if val_macro_OA >= best_macro_OA_now and it>200:
                    best_macro_OA_now = val_macro_OA
                    best_macro_OA_iter_now = it
                    state = {
                        "epoch": it,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_macro_OA_now": best_macro_OA_now,
                        'best_macro_OA_iter_now':best_macro_OA_iter_now,
                    }
                    save_path = os.path.join(writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format(cfg.model.arch,cfg.data.dataloader))
                    torch.save(state, save_path)

                    logger.info("best OA now =  %.8f" % (best_macro_OA_now))
                    logger.info("best OA iter now= %d" % (best_macro_OA_iter_now))

                train_val_time = time.time() - train_val_start_time
                remain_time = train_val_time * (train_iter-it) / it
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if s != 0:
                    train_time = "Remaining training time = %d hours %d minutes %d seconds \n" % (h, m, s)
                else:
                    train_time = "Remaining training time: training completed.\n"
                logger.info(train_time)
                model.train()

            train_start_time = time.time() 

    logger.info("best OA now =  %.8f" % (best_macro_OA_now))
    logger.info("best OA iter now= %d" % (best_macro_OA_iter_now))

    state = {
            "epoch": it,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_macro_OA_now": best_macro_OA_now,
            'best_macro_OA_iter_now':best_macro_OA_iter_now,
            }
    save_path = os.path.join(writer.file_writer.get_logdir(), "{}_{}_last_model.pkl".format(cfg.model.arch, cfg.data.dataloader))
    torch.save(state, save_path)
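The commented-out seeding block at the top of train() hints at how reproducibility would be set up. A hedged sketch of such a helper (the function name is illustrative and not part of this repo):

import random
import numpy as np
import torch

def set_seed(seed: int = 1337):
    # Seed the Python, NumPy and PyTorch RNGs; full determinism may additionally
    # require torch.backends.cudnn.deterministic = True and benchmark = False.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)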
Example #12
0
def train_and_eval(tag,
                   dataroot,
                   test_ratio=0.0,
                   cv_fold=0,
                   reporter=None,
                   metric='last',
                   save_path=None,
                   only_eval=False):
    if not reporter:
        reporter = lambda **kwargs: 0

    max_epoch = C.get()['epoch']
    trainsampler, trainloader, validloader, testloader_ = get_dataloaders(
        C.get()['dataset'],
        C.get()['batch'],
        dataroot,
        test_ratio,
        split_idx=cv_fold)

    # create a model & an optimizer
    model = get_model(C.get()['model'], num_class(C.get()['dataset']))

    lb_smooth = C.get()['optimizer'].get('label_smoothing', 0.0)
    if lb_smooth > 0.0:
        criterion = SmoothCrossEntropyLoss(lb_smooth)
    else:
        criterion = nn.CrossEntropyLoss()
    if C.get()['optimizer']['type'] == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=C.get()['lr'],
                              momentum=C.get()['optimizer'].get(
                                  'momentum', 0.9),
                              weight_decay=C.get()['optimizer']['decay'],
                              nesterov=C.get()['optimizer']['nesterov'])
    else:
        raise ValueError('invalid optimizer type=%s' %
                         C.get()['optimizer']['type'])

    if C.get()['optimizer'].get('lars', False):
        from torchlars import LARS
        optimizer = LARS(optimizer)
        logger.info('*** LARS Enabled.')

    lr_scheduler_type = C.get()['lr_schedule'].get('type', 'cosine')
    if lr_scheduler_type == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=C.get()['epoch'], eta_min=0.)
    elif lr_scheduler_type == 'resnet':
        scheduler = adjust_learning_rate_resnet(optimizer)
    else:
        raise ValueError('invalid lr_scheduler=%s' % lr_scheduler_type)

    if C.get()['lr_schedule'].get('warmup', None):
        scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=C.get()['lr_schedule']['warmup']['multiplier'],
            total_epoch=C.get()['lr_schedule']['warmup']['epoch'],
            after_scheduler=scheduler)

    if not tag:
        from RandAugment.metrics import SummaryWriterDummy as SummaryWriter
        logger.warning('tag not provided, no tensorboard log.')
    else:
        from tensorboardX import SummaryWriter
    writers = [
        SummaryWriter(log_dir='./logs/%s/%s' % (tag, x))
        for x in ['train', 'valid', 'test']
    ]

    result = OrderedDict()
    epoch_start = 1
    if save_path and os.path.exists(save_path):
        logger.info('%s file found. loading...' % save_path)
        data = torch.load(save_path)
        if 'model' in data or 'state_dict' in data:
            key = 'model' if 'model' in data else 'state_dict'
            logger.info('checkpoint epoch@%d' % data['epoch'])
            if not isinstance(model, DataParallel):
                model.load_state_dict({
                    k.replace('module.', ''): v
                    for k, v in data[key].items()
                })
            else:
                model.load_state_dict({
                    k if 'module.' in k else 'module.' + k: v
                    for k, v in data[key].items()
                })
            optimizer.load_state_dict(data['optimizer'])
            if data['epoch'] < C.get()['epoch']:
                epoch_start = data['epoch']
            else:
                only_eval = True
        else:
            model.load_state_dict({k: v for k, v in data.items()})
        del data
    else:
        logger.info('"%s" file not found. skip to pretrain weights...' %
                    save_path)
        if only_eval:
            logger.warning(
                'model checkpoint not found. only-evaluation mode is off.')
        only_eval = False

    if only_eval:
        logger.info('evaluation only')
        model.eval()
        rs = dict()
        rs['train'] = run_epoch(model,
                                trainloader,
                                criterion,
                                None,
                                desc_default='train',
                                epoch=0,
                                writer=writers[0])
        rs['valid'] = run_epoch(model,
                                validloader,
                                criterion,
                                None,
                                desc_default='valid',
                                epoch=0,
                                writer=writers[1])
        rs['test'] = run_epoch(model,
                               testloader_,
                               criterion,
                               None,
                               desc_default='*test',
                               epoch=0,
                               writer=writers[2])
        for key, setname in itertools.product(['loss', 'top1', 'top5'],
                                              ['train', 'valid', 'test']):
            if setname not in rs:
                continue
            result['%s_%s' % (key, setname)] = rs[setname][key]
        result['epoch'] = 0
        return result

    # train loop
    best_top1 = 0
    for epoch in range(epoch_start, max_epoch + 1):
        model.train()
        rs = dict()
        rs['train'] = run_epoch(model,
                                trainloader,
                                criterion,
                                optimizer,
                                desc_default='train',
                                epoch=epoch,
                                writer=writers[0],
                                verbose=True,
                                scheduler=scheduler)
        model.eval()

        if math.isnan(rs['train']['loss']):
            raise Exception('train loss is NaN.')

        if epoch % 5 == 0 or epoch == max_epoch:
            rs['valid'] = run_epoch(model,
                                    validloader,
                                    criterion,
                                    None,
                                    desc_default='valid',
                                    epoch=epoch,
                                    writer=writers[1],
                                    verbose=True)
            rs['test'] = run_epoch(model,
                                   testloader_,
                                   criterion,
                                   None,
                                   desc_default='*test',
                                   epoch=epoch,
                                   writer=writers[2],
                                   verbose=True)

            if metric == 'last' or rs[metric]['top1'] > best_top1:
                if metric != 'last':
                    best_top1 = rs[metric]['top1']
                for key, setname in itertools.product(
                    ['loss', 'top1', 'top5'], ['train', 'valid', 'test']):
                    result['%s_%s' % (key, setname)] = rs[setname][key]
                result['epoch'] = epoch

                writers[1].add_scalar('valid_top1/best', rs['valid']['top1'],
                                      epoch)
                writers[2].add_scalar('test_top1/best', rs['test']['top1'],
                                      epoch)

                reporter(loss_valid=rs['valid']['loss'],
                         top1_valid=rs['valid']['top1'],
                         loss_test=rs['test']['loss'],
                         top1_test=rs['test']['top1'])

                # save checkpoint
                if save_path:
                    logger.info('save model@%d to %s' % (epoch, save_path))
                    torch.save(
                        {
                            'epoch': epoch,
                            'log': {
                                'train': rs['train'].get_dict(),
                                'valid': rs['valid'].get_dict(),
                                'test': rs['test'].get_dict(),
                            },
                            'optimizer': optimizer.state_dict(),
                            'model': model.state_dict()
                        }, save_path)
                    torch.save(
                        {
                            'epoch': epoch,
                            'log': {
                                'train': rs['train'].get_dict(),
                                'valid': rs['valid'].get_dict(),
                                'test': rs['test'].get_dict(),
                            },
                            'optimizer': optimizer.state_dict(),
                            'model': model.state_dict()
                        },
                        save_path.replace(
                            '.pth', '_e%d_top1_%.3f_%.3f' %
                            (epoch, rs['train']['top1'], rs['test']['top1']) +
                            '.pth'))

    del model

    result['top1_test'] = best_top1
    return result
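The resume logic above strips or prepends the 'module.' prefix so a checkpoint saved with or without DataParallel can be loaded either way. A small hedged helper capturing the same idea (illustrative only, not part of the RandAugment code):

def adapt_state_dict(state_dict, wrapped_in_data_parallel):
    # Make checkpoint keys match the target model: DataParallel expects a
    # 'module.' prefix on every key, a bare model expects none.
    if wrapped_in_data_parallel:
        return {k if k.startswith('module.') else 'module.' + k: v
                for k, v in state_dict.items()}
    return {k.replace('module.', '', 1): v for k, v in state_dict.items()}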
Example #13
0
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=4096,
                        metavar='N',
                        help='input batch size for training (default: 4096)')
    parser.add_argument('--optimizer',
                        type=str,
                        default='lamb',
                        choices=['lamb', 'adam', 'lars', 'sgd'],
                        help='which optimizer to use')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=2048,
                        metavar='N',
                        help='input batch size for testing (default: 2048)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--wd',
                        type=float,
                        default=0.01,
                        metavar='WD',
                        help='weight decay (default: 0.01)')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--eta',
                        type=float,
                        default=0.001,
                        metavar='e',
                        help='LARS coefficient (default: 0.001)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')

    args = parser.parse_args()
    use_cuda = torch.cuda.is_available()
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    print(device)
    print(args)
    print("*" * 50)

    kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        # datasets.MNIST('../data', train=True, download=True,
        datasets.EMNIST('../data',
                        train=True,
                        download=True,
                        split='letters',
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307, ), (0.3081, ))
                        ])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(
        # datasets.MNIST('../data', train=False, transform=transforms.Compose([
        datasets.EMNIST('../data',
                        train=False,
                        split='letters',
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307, ), (0.3081, ))
                        ])),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs)

    num_features = 26
    model = Net(num_outputs=num_features).to(device)

    writer = SummaryWriter(comment="_cv_%s_%s_%s" %
                           (args.optimizer, args.batch_size, args.lr))
    weight_decay = args.lr / args.epochs
    print(len(train_loader), len(test_loader))
    print(model)
    print(f'total params ---> {count_parameters(model)}')
    if args.optimizer == 'lamb':
        optimizer = Lamb(model.parameters(),
                         lr=args.lr,
                         weight_decay=args.wd,
                         betas=(.9, .999),
                         adam=False,
                         writer=writer)
    elif args.optimizer == 'lars':
        base_optimizer = torch.optim.SGD(model.parameters(),
                                         lr=args.lr,
                                         momentum=0.9)
        optimizer = LARS(optimizer=base_optimizer,
                         eps=1e-8,
                         trust_coef=0.001,
                         writer=writer)
    elif args.optimizer == 'sgd':
        optimizer = SGD(model.parameters(),
                        lr=args.lr,
                        momentum=0.9,
                        weight_decay=args.wd,
                        writer=writer)
    else:
        # use adam optimizer
        optimizer = Lamb(model.parameters(),
                         lr=args.lr,
                         weight_decay=args.wd,
                         betas=(.9, .999),
                         adam=True,
                         writer=writer)
    print(f'Currently using the {args.optimizer} optimizer\n\n')

    metrics = {"acc": [], "test_loss": []}
    os.makedirs("cv_results", exist_ok=True)

    for epoch in range(1, args.epochs + 1):
        print("Epoch #%s" % epoch)
        train(args, model, device, train_loader, optimizer, epoch, writer)
        acc, loss = test(args, model, device, test_loader, writer, epoch)
        metrics["acc"].append(acc)
        metrics["test_loss"].append(loss)
        pickle.dump(
            metrics,
            open(
                os.path.join(
                    "cv_results",
                    '%s_%s_metrics.p' % (args.optimizer, args.batch_size)),
                'wb'))
    #print("Epoch #$s: acc=%s, loss=%s" % (epoch, acc, loss))
    return optimizer
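For reference, the layer-wise adaptation that LARS applies on top of the wrapped SGD roughly follows the trust-ratio rule from the LARS paper; a simplified sketch of that rule (torchlars' internals may differ in detail):

import torch

def lars_local_lr(param, grad, weight_decay=0.0, trust_coef=0.001, eps=1e-8):
    # Scale the base learning rate by trust_coef * ||w|| / (||g|| + wd * ||w|| + eps),
    # so layers whose gradients are small relative to their weights take larger steps.
    w_norm = torch.norm(param)
    g_norm = torch.norm(grad)
    if w_norm == 0 or g_norm == 0:
        return 1.0
    return float(trust_coef * w_norm / (g_norm + weight_decay * w_norm + eps))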
Example #14
0
def train(cfg, writer, logger):

    # Setup Augmentations
    augmentations = cfg.train.augment
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path
    logger.info("data path: {}".format(data_path))

    t_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        norm=cfg.data.norm,
        split='train',
        split_root=cfg.data.split,
        augments=data_aug,
        logger=logger,
        log=cfg.data.log,
        ENL=cfg.data.ENL,
    )

    v_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        split='val',
        log=cfg.data.log,
        split_root=cfg.data.split,
        logger=logger,
        ENL=cfg.data.ENL,
    )

    train_data_len = len(t_loader)
    logger.info(
        f'num of train samples: {train_data_len} \nnum of val samples: {len(v_loader)}'
    )

    batch_size = cfg.train.batch_size
    epoch = cfg.train.epoch
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')

    trainloader = data.DataLoader(t_loader,
                                  batch_size=batch_size,
                                  num_workers=cfg.train.n_workers,
                                  shuffle=True,
                                  persistent_workers=True,
                                  drop_last=True)

    valloader = data.DataLoader(
        v_loader,
        batch_size=cfg.test.batch_size,
        # persis
        num_workers=cfg.train.n_workers,
    )

    # Setup Model
    device = f'cuda:{cfg.train.gpu[0]}'
    model = get_model(cfg.model).to(device)
    input_size = (cfg.model.in_channels, 512, 512)
    logger.info(f"Using Model: {cfg.model.arch}")
    # logger.info(f'model summary: {summary(model, input_size=(input_size, input_size), is_complex=False)}')
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)  # runs on multiple GPUs automatically; works well

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in vars(cfg.train.optimizer).items()
        if k not in ('name', 'wrap')
    }
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    if hasattr(cfg.train.optimizer,
               'wrap') and cfg.train.optimizer.wrap == 'lars':
        optimizer = LARS(optimizer=optimizer)
        logger.info(f'wrap optimizer with {cfg.train.optimizer.wrap}')
    scheduler = get_scheduler(optimizer, cfg.train.lr)
    # loss_fn = get_loss_function(cfg)
    # logger.info(f"Using loss ,{str(cfg.train.loss)}")

    # load checkpoints
    val_cls_1_acc = 0
    best_cls_1_acc_now = 0
    best_cls_1_acc_iter_now = 0
    val_macro_OA = 0
    best_macro_OA_now = 0
    best_macro_OA_iter_now = 0
    start_iter = 0
    if cfg.train.resume is not None:
        if os.path.isfile(cfg.train.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg.train.resume))

            # load model state
            checkpoint = torch.load(cfg.train.resume)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # best_cls_1_acc_now = checkpoint["best_cls_1_acc_now"]
            # best_cls_1_acc_iter_now = checkpoint["best_cls_1_acc_iter_now"]

            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg.train.resume, checkpoint["epoch"]))

            # copy tensorboard files
            resume_src_dir = osp.split(cfg.train.resume)[0]
            # shutil.copytree(resume_src_dir, writer.get_logdir())
            for file in os.listdir(resume_src_dir):
                if not ('.log' in file or '.yml' in file
                        or '_last_model' in file):
                    # if 'events.out.tfevents' in file:
                    resume_dst_dir = writer.get_logdir()
                    fu.copy(
                        osp.join(resume_src_dir, file),
                        resume_dst_dir,
                    )

        else:
            logger.info("No checkpoint found at '{}'".format(cfg.train.resume))

    data_range = 255
    if cfg.data.log:
        data_range = np.log(data_range)
    # data_range /= 350

    # Setup Metrics
    running_metrics_val = runningScore(2)
    runing_metrics_train = runningScore(2)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()
    train_loss_meter = averageMeter()
    val_psnr_meter = averageMeter()
    val_ssim_meter = averageMeter()

    # train
    it = start_iter
    train_start_time = time.time()
    train_val_start_time = time.time()
    model.train()
    while it < train_iter:
        for clean, noisy, _ in trainloader:
            it += 1

            noisy = noisy.to(device, dtype=torch.float32)
            # noisy /= 350
            mask1, mask2 = rand_pool.generate_mask_pair(noisy)
            noisy_sub1 = rand_pool.generate_subimages(noisy, mask1)
            noisy_sub2 = rand_pool.generate_subimages(noisy, mask2)

            # preparing for the regularization term
            with torch.no_grad():
                noisy_denoised = model(noisy)
            noisy_sub1_denoised = rand_pool.generate_subimages(
                noisy_denoised, mask1)
            noisy_sub2_denoised = rand_pool.generate_subimages(
                noisy_denoised, mask2)
            # print(rand_pool.operation_seed_counter)

            # for ii, param in enumerate(model.parameters()):
            #     if torch.sum(torch.isnan(param.data)):
            #         print(f'{ii}: nan parameters')

            # calculating the loss
            noisy_output = model(noisy_sub1)
            noisy_target = noisy_sub2
            if cfg.train.loss.gamma.const:
                gamma = cfg.train.loss.gamma.base
            else:
                gamma = it / train_iter * cfg.train.loss.gamma.base

            diff = noisy_output - noisy_target
            exp_diff = noisy_sub1_denoised - noisy_sub2_denoised
            loss1 = torch.mean(diff**2)
            loss2 = gamma * torch.mean((diff - exp_diff)**2)
            loss_all = loss1 + loss2

            # loss1 = noisy_output - noisy_target
            # loss2 = torch.exp(noisy_target - noisy_output)
            # loss_all = torch.mean(loss1 + loss2)
            optimizer.zero_grad()   # clear gradients accumulated from the previous iteration
            loss_all.backward()

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()

            # record the loss of the minibatch
            train_loss_meter.update(loss_all.item())  # detach scalar so the graph is not retained
            train_time_meter.update(time.time() - train_start_time)
            writer.add_scalar('lr', optimizer.param_groups[0]['lr'], it)

            if it % 1000 == 0:
                writer.add_histogram('hist/pred', noisy_denoised, it)
                writer.add_histogram('hist/noisy', noisy, it)

                if cfg.data.simulate:
                    writer.add_histogram('hist/clean', clean, it)

            if cfg.data.simulate:
                pass

            # print interval
            if it % cfg.train.print_interval == 0:
                terminal_info = f"Iter [{it:d}/{train_iter:d}]  \
                                train Loss: {train_loss_meter.avg:.4f}  \
                                Time/Image: {train_time_meter.avg / cfg.train.batch_size:.4f}"

                logger.info(terminal_info)
                writer.add_scalar('loss/train_loss', train_loss_meter.avg, it)

                if cfg.data.simulate:
                    pass

                runing_metrics_train.reset()
                train_time_meter.reset()
                train_loss_meter.reset()

            # val interval
            if it % cfg.train.val_interval == 0 or \
               it == train_iter:
                val_start_time = time.time()
                model.eval()
                with torch.no_grad():
                    for clean, noisy, _ in valloader:
                        # noisy /= 350
                        # clean /= 350
                        noisy = noisy.to(device, dtype=torch.float32)
                        noisy_denoised = model(noisy)

                        if cfg.data.simulate:
                            clean = clean.to(device, dtype=torch.float32)
                            psnr = piq.psnr(clean,
                                            noisy_denoised,
                                            data_range=data_range)
                            ssim = piq.ssim(clean,
                                            noisy_denoised,
                                            data_range=data_range)
                            val_psnr_meter.update(psnr)
                            val_ssim_meter.update(ssim)

                        val_loss = torch.mean((noisy_denoised - noisy)**2)
                        val_loss_meter.update(val_loss)

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it)
                logger.info(
                    f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} Time/Image: {(time.time()-val_start_time)/len(v_loader):.4f}"
                )
                val_loss_meter.reset()
                running_metrics_val.reset()

                if cfg.data.simulate:
                    writer.add_scalars('metrics/val', {
                        'psnr': val_psnr_meter.avg,
                        'ssim': val_ssim_meter.avg
                    }, it)
                    logger.info(
                        f'psnr: {val_psnr_meter.avg},\tssim: {val_ssim_meter.avg}'
                    )
                    val_psnr_meter.reset()
                    val_ssim_meter.reset()

                train_val_time = time.time() - train_val_start_time
                remain_time = train_val_time * (train_iter - it) / it
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if s != 0:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (
                        h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                logger.info(train_time)
                model.train()

            # save model
            if it % (train_iter / cfg.train.epoch * 10) == 0:
                ep = int(it / (train_iter / cfg.train.epoch))
                state = {
                    "epoch": it,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                }
                save_path = osp.join(writer.file_writer.get_logdir(),
                                     f"{ep}.pkl")
                torch.save(state, save_path)
                logger.info(f'saved model state dict at {save_path}')

            train_start_time = time.time()
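The loss above pairs two randomly sub-sampled views of the noisy input and adds a gamma-weighted consistency term against the full denoised output. A self-contained sketch of that computation (the function name and argument names are illustrative):

import torch

def pair_consistency_loss(noisy_output, noisy_target, denoised_sub1, denoised_sub2, gamma):
    # Reconstruction term between paired sub-images plus a regularizer that pushes
    # their residual toward the residual of the corresponding denoised sub-images.
    diff = noisy_output - noisy_target
    exp_diff = denoised_sub1 - denoised_sub2
    loss_rec = torch.mean(diff ** 2)
    loss_reg = gamma * torch.mean((diff - exp_diff) ** 2)
    return loss_rec + loss_reg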
Example #15
0
    resume_from = os.path.join('./checkpoint', args.resume)
    checkpoint = torch.load(resume_from)
    net.load_state_dict(checkpoint['net'])
    critic.load_state_dict(checkpoint['critic'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
base_optimizer = optim.SGD(list(net.parameters()) + list(critic.parameters()),
                           lr=args.lr,
                           weight_decay=1e-6,
                           momentum=args.momentum)
if args.cosine_anneal:
    scheduler = CosineAnnealingWithLinearRampLR(base_optimizer,
                                                args.num_epochs)
encoder_optimizer = LARS(base_optimizer, trust_coef=1e-3)


# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    critic.train()
    train_loss = 0
    t = tqdm(enumerate(trainloader),
             desc='Loss: **** ',
             total=len(trainloader),
             bar_format='{desc}{bar}{r_bar}')
    for batch_idx, (inputs, _, _) in t:
        x1, x2 = inputs
        x1, x2 = x1.to(device), x2.to(device)
Example #16
0
simclr_aug = C.get_simclr_augmentation(P, image_size=P.image_size).to(device)
P.shift_trans, P.K_shift = C.get_shift_module(P, eval=True)
P.shift_trans = P.shift_trans.to(device)

model = C.get_classifier(P.model, n_classes=P.n_classes).to(device)
model = C.get_shift_classifer(model, P.K_shift).to(device)

criterion = nn.CrossEntropyLoss().to(device)

if P.optimizer == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=P.lr_init, momentum=0.9, weight_decay=P.weight_decay)
    lr_decay_gamma = 0.1
elif P.optimizer == 'lars':
    from torchlars import LARS
    base_optimizer = optim.SGD(model.parameters(), lr=P.lr_init, momentum=0.9, weight_decay=P.weight_decay)
    optimizer = LARS(base_optimizer, eps=1e-8, trust_coef=0.001)
    lr_decay_gamma = 0.1
else:
    raise NotImplementedError()

if P.lr_scheduler == 'cosine':
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, P.epochs)
elif P.lr_scheduler == 'step_decay':
    milestones = [int(0.5 * P.epochs), int(0.75 * P.epochs)]
    scheduler = lr_scheduler.MultiStepLR(optimizer, gamma=lr_decay_gamma, milestones=milestones)
else:
    raise NotImplementedError()

from training.scheduler import GradualWarmupScheduler
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=10.0, total_epoch=P.warmup, after_scheduler=scheduler)
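GradualWarmupScheduler is typically stepped once per epoch, so the learning rate ramps up for P.warmup epochs and then hands over to the wrapped cosine/step schedule. A hedged usage sketch (the per-epoch training call is a placeholder, not from this example):

for epoch in range(1, P.epochs + 1):
    # train_one_epoch(model, optimizer, criterion)  # hypothetical training helper
    scheduler_warmup.step()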
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            pin_memory=True)
    loaders = [train_loader, val_loader]

    print('Total number of training images: {}, validation images: {}.'.format(
        len(train_dataset), len(val_dataset)))

    model = resnet18_cifar(args.feature_size).to(device)
    model = torch.nn.DataParallel(model)

    base_optimizer = optim.Adam(model.parameters(),
                                lr=args.lr,
                                weight_decay=args.decay_lr)
    optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)
    scheduler = ExponentialLR(optimizer, gamma=args.decay_lr)

    # Main training loop
    best_loss = np.inf

    for epoch in range(args.epochs):
        v_loss = execute_graph(model, loaders, optimizer, scheduler, epoch)

        if v_loss < best_loss:
            best_loss = v_loss
            print('Writing model checkpoint')
            state = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
class task2_train():
    def __init__(self, device, train_loader, val_loader, model, MODEL_PATH):
        self.device = device
        self.model_path = MODEL_PATH
        self.train_loader = train_loader
        self.val_loader = val_loader

        self.model = model.to(device)
        base_optimizer = optim.Adam(self.model.parameters(), lr=2e-4)
        self.optimizer = LARS(optimizer=base_optimizer,
                              eps=1e-8,
                              trust_coef=0.001)
        self.criterion = nn.BCEWithLogitsLoss()
        self.scaler = GradScaler()

    def train(self):
        self.model.train()
        train_loss = 0

        label_pred = []
        label_true = []

        AC_total, SE_total, SP_total, DC_total, JS_total = 0, 0, 0, 0, 0

        start = time.time()

        for batch_idx, (inputs, GT) in enumerate(self.train_loader):
            if batch_idx % 10 == 0 and batch_idx:
                end = time.time()
                print("finished training %s images, using %.4fs" %
                      (str(batch_idx * 8), end - start))
            inputs, GT = inputs.to(self.device), GT.to(self.device)
            self.optimizer.zero_grad()
            with autocast():
                SR = self.model(inputs)
                loss = loss_f(inputs=SR, targets=GT.long())

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            train_loss += loss.item()

            label_pred.append(SR.data.max(dim=1)[1].cpu().numpy())
            label_true.append(GT.data.cpu().numpy())

        label_pred = np.concatenate(label_pred, axis=0)
        label_true = np.concatenate(label_true, axis=0)

        for lbt, lbp in zip(label_true, label_pred):
            AC, SE, SP, DC, JS = self.evaluate(lbt, lbp)
            AC_total += AC
            SE_total += SE
            SP_total += SP
            DC_total += DC
            JS_total += JS

        print(
            'Training Loss: %.4f\nAC: %.4f, SE: %.4f, SP: %.4f, DC: %.4f, JS: %.4f'
            % (train_loss, AC_total / len(label_true),
               SE_total / len(label_true), SP_total / len(label_true),
               DC_total / len(label_true), JS_total / len(label_true)))

    def _fast_hist(self, truth, pred):
        mask = (truth >= 0) & (truth < 19)
        hist = np.bincount(19 * truth[mask].astype(int) + pred[mask],
                           minlength=19**2).reshape(19, 19)
        return hist

    def evaluate(self, ground_truth, predictions, smooth=1):

        confusion_matrix = np.zeros((19, 19))

        for lt, lp in zip(ground_truth, predictions):
            confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten())

        fn = confusion_matrix.sum(1) - np.diag(confusion_matrix)
        fp = confusion_matrix.sum(0) - np.diag(confusion_matrix)
        tp = np.diag(confusion_matrix)
        tn = np.array([confusion_matrix.sum() for i in range(19)
                       ]) - confusion_matrix.sum(1) - confusion_matrix.sum(
                           0) + np.diag(confusion_matrix)

        AC_array = (tp + tn) / np.maximum(1.0, fn + fp + tp + tn)
        AC = AC_array.mean()
        SE_array = (tp) / np.maximum(1.0, confusion_matrix.sum(1))
        SE = SE_array.mean()
        SP_array = tn / np.maximum(1.0, tn + fp)
        SP = SP_array.mean()
        DC_array = (2 * tp) / np.maximum(
            1.0,
            confusion_matrix.sum(0) + confusion_matrix.sum(1))
        DC = DC_array.mean()
        JS_array = (tp) / np.maximum(
            1.0,
            confusion_matrix.sum(0) + confusion_matrix.sum(1) -
            np.diag(confusion_matrix))

        JS = JS_array.mean()

        return AC, SE, SP, DC, JS

    def saveModel(self):
        torch.save(self.model.state_dict(), self.model_path + "model.pth")
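A hedged usage sketch of the class above; the model and data-loader constructors below are placeholders and not part of this example:

import torch

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = build_segmentation_model()          # hypothetical constructor
    train_loader, val_loader = build_loaders()  # hypothetical constructor
    trainer = task2_train(device, train_loader, val_loader, model, MODEL_PATH='./weights/')
    for epoch in range(50):                     # epoch count is an assumption
        trainer.train()
    trainer.saveModel()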
Example #19
0
                      momentum=0.9,
                      weight_decay=args.weight_decay)
    lr_decay_gamma = 0.1
elif args.optimizer == 'adam':
    optim = optim.Adam(model.parameters(),
                       lr=args.lr,
                       betas=(.9, .999),
                       weight_decay=args.weight_decay)
    lr_decay_gamma = 0.3
elif args.optimizer == 'lars':
    from torchlars import LARS
    base_optimizer = optim.SGD(model.parameters(),
                               lr=args.lr,
                               momentum=0.9,
                               weight_decay=args.weight_decay)
    optim = LARS(base_optimizer, eps=1e-8, trust_coef=0.001)
    lr_decay_gamma = 0.1
elif args.optimizer == 'ranger':
    from ranger import Ranger
    optim = Ranger(model.parameters(),
                   weight_decay=args.weight_decay,
                   lr=args.lr)
else:
    raise NotImplementedError()

# normal_class=[args.known_normal], known_outlier_class=args.known_outlier,
print("known_normal:", args.known_normal, "known_outlier:", args.known_outlier)
# if args.lr_scheduler == 'cosine':
# scheduler = lr_scheduler.CosineAnnealingLR(optim, args.n_epochs)
# print("use cosine scheduler")
# Evaluation before training
Example #20
0
    args.add_argument('--seed', type=int)
    args.add_argument('--lr', type=float, default=1e-4)
    config = args.parse_args()

    if 'DEVICE' in os.environ:
        config.device = os.environ['DEVICE']

    wandb.init(project='EfficientNet_MNIST', config=config)

    if config.seed is not None:
        torch.manual_seed(config.seed)

    eff_net = efficient_net.EfficientNet(version="b0",
                                         num_classes=10).to(config.device)
    if config.lars:
        optim = LARS(Adam(eff_net.parameters(), lr=config.lr))
    else:
        optim = Adam(eff_net.parameters(), lr=config.lr)

    ds = MNIST('~/.mnist',
               train=True,
               download=True,
               transform=transforms.ToTensor())
    dl = DataLoader(ds, batch_size=config.batch_size)

    test_s = MNIST('~/.mnist',
                   train=False,
                   download=True,
                   transform=transforms.ToTensor())
    test = DataLoader(test_s, batch_size=config.batch_size)
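The fragment stops before the training loop; a minimal hedged sketch of how dl/test and the LARS-wrapped Adam could be used (loss choice, epoch count, and logging are assumptions, not from the original):

import torch.nn.functional as F

for epoch in range(10):
    eff_net.train()
    for x, y in dl:
        x, y = x.to(config.device), y.to(config.device)
        optim.zero_grad()
        loss = F.cross_entropy(eff_net(x), y)
        loss.backward()
        optim.step()

    # quick accuracy check on the held-out split
    eff_net.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in test:
            x, y = x.to(config.device), y.to(config.device)
            correct += (eff_net(x).argmax(dim=1) == y).sum().item()
            total += y.numel()
    wandb.log({'epoch': epoch, 'train_loss': loss.item(), 'test_acc': correct / total})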