def train(gpu, args):

    # global rank of this process: node index * GPUs per node + local GPU index
    rank = args.nr * args.num_gpus + gpu

    dist.init_process_group(backend="nccl", world_size=args.world_size, rank=rank)

    if args.batch_size == 1 and args.use_bn:
        raise ValueError("batch_size=1 is incompatible with batch normalization (use_bn=True)")

    torch.autograd.set_detect_anomaly(True)
    torch.manual_seed(args.torch_seed)
    torch.cuda.manual_seed(args.cuda_seed)
    
    torch.cuda.set_device(gpu)
    
    DATASET_NAME = args.dataset_name
    DATA_ROOT = args.data_root
    OVERLAYS_ROOT = args.overlays_root

    if args.model_name == 'dss':
        model = DSS_Net(args, n_channels=3, n_classes=1, bilinear=True)
        loss = FocalLoss()
    elif args.model_name == 'unet':
        model = UNet(args)
        loss = nn.BCELoss()
    else:
        raise NotImplementedError

    # model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    print(f"Using {torch.cuda.device_count()} GPUs...")
        
    # define dataset
    if DATASET_NAME == 'synthetic':
        assert (args.overlays_root != "")
        #train_dataset = SimpleSmokeTrain(args, dataset_limit=args.num_examples)
        train_dataset = SyntheticSmokeTrain(args, DATA_ROOT, OVERLAYS_ROOT, dataset_limit=args.num_examples)
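        # shard the training set across processes; per-epoch shuffling is driven by set_epoch() in the loop below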
        train_sampler = DistributedSampler(train_dataset, num_replicas=args.world_size, rank=rank, shuffle=True)
        train_dataloader = DataLoader(train_dataset, args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)
        if args.validate:
            val_dataset = SmokeDataset()
        else:
            val_dataset = None
        val_dataloader = DataLoader(val_dataset, args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True) if val_dataset else None
    else:
        raise NotImplementedError

    # define augmentations
    augmentations = None #SyntheticAugmentation(args)
    
    # load the model
    print("Loding model and augmentations and placing on gpu...")

    if args.cuda:
        if augmentations is not None:
            augmentations = augmentations.cuda()

        model = model.cuda(device=gpu)
            
        if args.num_gpus > 0 or torch.cuda.device_count() > 0:
            model = DistributedDataParallel(model, device_ids=[gpu])
                
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"The model has {num_params} learnable parameters")

    # load optimizer and lr scheduler
    optimizer = Adam(model.parameters(), lr=args.lr, betas=[args.momentum, args.beta], weight_decay=args.weight_decay)

    if args.lr_sched_type == 'plateau':
        print("Using plateau lr schedule")
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=args.lr_gamma, verbose=True, mode='min', patience=10)
    elif args.lr_sched_type == 'step':
        print("Using step lr schedule")
        milestones = [30]
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestones, gamma=args.lr_gamma)
    elif args.lr_sched_type == 'none':
        lr_scheduler = None
    else:
        raise NotImplementedError(f"Unknown lr_sched_type: {args.lr_sched_type}")

    # set up logging
    if not args.no_logging and gpu == 0:
        if not os.path.isdir(args.log_dir):
            os.mkdir(args.log_dir)
        log_dir = os.path.join(args.log_dir, args.exp_dir)
        if not os.path.isdir(log_dir):
            os.mkdir(log_dir)
        if args.exp_name == "":
            exp_name = datetime.datetime.now().strftime("%H%M%S-%Y%m%d")
        else:
            exp_name = args.exp_name
        log_dir = os.path.join(log_dir, exp_name)
        writer = SummaryWriter(log_dir)

    if args.ckpt != "" and args.use_pretrained:
        state_dict = torch.load(args.ckpt)['state_dict']
        model.load_state_dict(state_dict)
    elif args.start_epoch > 0:
        load_epoch = args.start_epoch - 1
        ckpt_fp = os.path.join(log_dir, f"{load_epoch}.ckpt")

        print(f"Loading model from {ckpt_fp}...")

        ckpt = torch.load(ckpt_fp)
        assert (ckpt['epoch'] ==
                load_epoch), "epoch from state dict does not match with args"
        model.load_state_dict(ckpt['state_dict'])

    model.train()
    
    # run training loop
    for epoch in range(args.start_epoch, args.epochs + 1):
        print(f"Training epoch: {epoch}...")
        train_sampler.set_epoch(epoch)
        train_loss_avg, pred_mask, input_dict = train_one_epoch(
            args, model, loss, train_dataloader, optimizer, augmentations, lr_scheduler)
        if gpu == 0:
            print(f"\t Epoch {epoch} train loss avg:")
            pprint(train_loss_avg)

        if val_dataset is not None:
            print(f"Validation epoch: {epoch}...")
            val_loss_avg = eval(args, model, loss, val_dataloader, augmentations)
            print(f"\t Epoch {epoch} val loss avg: {val_loss_avg}")

        if not args.no_logging and gpu == 0:
            writer.add_scalar(f'loss/train', train_loss_avg, epoch)
            if epoch % args.log_freq == 0:
                visualize_output(args, input_dict, pred_mask, epoch, writer)

        if args.lr_sched_type == 'plateau':
            lr_scheduler.step(train_loss_avg)
        elif args.lr_sched_type == 'step':
            lr_scheduler.step(epoch)

        # save model
        if not args.no_logging and gpu == 0:
            if epoch % args.save_freq == 0 or epoch == args.epochs:
                fp = os.path.join(log_dir, f"{epoch}.ckpt")
                # save the epoch alongside the weights so the resume path above can verify it
                torch.save({'epoch': epoch, 'state_dict': model.state_dict()}, fp)

            writer.flush()

    return
Example 2
    #                               init_lamda=1.5*torch.ones((2, 4), dtype=torch.float)).cuda()

    fc = torch.nn.Linear(2, 2).cuda()
    layer = Membership_norm(2, 4,
                            init_c=-5 * torch.ones((2, 4), dtype=torch.float),
                            init_lamda=4 * torch.ones((2, 4), dtype=torch.float)).cuda()

    # x = torch.tensor([[[0.9, 0.1], [0.9, 0.1]], [[-0.9, 0.1], [-0.1, -2.5]]], dtype=torch.float, requires_grad=True)
    # x2 = x ** 2
    # print(x2.requires_grad)
    # print(x.shape)
    # print(layer(x))
    # print(layer.c)
    # print(x.shape)
    # loss_focal = torch.nn.MSELoss()
    loss_focal = FocalLoss()
    loss_center = CenterLoss()
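    # separate optimizer parameter groups for the linear layer and the membership c / lamda tensors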
    para = [
        {"params": fc.parameters(), "lr": 1e-3},
        {"params": layer.c, "lr": 1e-3},
        {"params": layer.lamda, "lr": 1e-3},
    ]

    # optim = torch.optim.SGD(para)
    optim = torch.optim.Adam(para)
    # bestloss = 1e5
    # bestnetweightfc = []
    # bestnetweightlayer = []

    for i in range(0, 100000):
        h = fc(x_in_tensor).unsqueeze(2)
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    print("=> creating model '{}'".format(args.arch))
    num_classes = 100 if args.dataset == 'cifar100' else 10
    use_norm = True if args.loss_type == 'LDAM' else False
    model = models.__dict__[args.arch](num_classes=num_classes,
                                       use_norm=use_norm)
    non_gpu_model = models.__dict__[args.arch](num_classes=num_classes,
                                               use_norm=use_norm)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cuda:0')
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    if args.dataset == 'cifar10':
        train_dataset = IMBALANCECIFAR10(root='./data',
                                         imb_type=args.imb_type,
                                         imb_factor=args.imb_factor,
                                         rand_number=args.rand_number,
                                         train=True,
                                         download=True,
                                         transform=transform_train)
        val_dataset = datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform_val)
    elif args.dataset == 'cifar100':
        train_dataset = IMBALANCECIFAR100(root='./data',
                                          imb_type=args.imb_type,
                                          imb_factor=args.imb_factor,
                                          rand_number=args.rand_number,
                                          train=True,
                                          download=True,
                                          transform=transform_train)
        val_dataset = datasets.CIFAR100(root='./data',
                                        train=False,
                                        download=True,
                                        transform=transform_val)
    else:
        warnings.warn('Dataset is not listed')
        return
    cls_num_list = train_dataset.get_cls_num_list()
    print('cls num list:')
    print(cls_num_list)
    args.cls_num_list = cls_num_list

    train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # init log for training
    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log_train.csv'), 'w')
    log_testing = open(
        os.path.join(args.root_log, args.store_name, 'log_test.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        if args.train_rule == 'None':
            train_sampler = None
            per_cls_weights = None
        elif args.train_rule == 'Resample':
            train_sampler = ImbalancedDatasetSampler(train_dataset)
            per_cls_weights = None
        elif args.train_rule == 'Reweight':
            train_sampler = None
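            # class-balanced "effective number" weights (Cui et al., 2019):
            # w_c = (1 - beta) / (1 - beta^{n_c}), normalized to sum to the number of classes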
            beta = 0.9999
            effective_num = 1.0 - np.power(beta, cls_num_list)
            per_cls_weights = (1.0 - beta) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
        elif args.train_rule == 'DRW':
            train_sampler = None
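            # deferred re-weighting (DRW): no re-weighting (beta=0) before epoch 160,
            # class-balanced effective-number weights afterwards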
            idx = min(epoch // 160, 1)
            betas = [0, 0.9999]
            effective_num = 1.0 - np.power(betas[idx], cls_num_list)
            per_cls_weights = (1.0 - betas[idx]) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
        elif args.train_rule == 'SWITCHING_DRW_AUTO_CLUSTER':
            train_sampler = None
            cutoff_epoch = args.cutoff_epoch
            idx = min(epoch // cutoff_epoch, 1)
            betas = [0, 0.9999]
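            # periodically after the cutoff epoch, re-cluster intermediate features into
            # pseudo-classes and derive class weights from the resulting (class, cluster) counts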
            if epoch >= cutoff_epoch + 10 and (epoch - cutoff_epoch) % 20 == 0:
                max_real_ratio_number_of_labels = 20
                # todo: transform data batch by batch, then concatentate...
                temp_batch_size = int(train_dataset.data.shape[0])
                temp_train_loader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=int(temp_batch_size),
                    shuffle=(train_sampler is None),
                    num_workers=args.workers,
                    pin_memory=True,
                    sampler=train_sampler)

                transformed_data = None
                transformed_labels = None
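                # iterate the regular train_loader batch by batch and concatenate flattened
                # intermediate features (the full-batch temp_train_loader above goes unused here)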
                for i, (xs, labels) in enumerate(train_loader):
                    transformed_batch_data = model.forward_partial(
                        xs.cuda(), num_layers=6)
                    transformed_batch_data = transformed_batch_data.cpu(
                    ).detach()
                    transformed_batch_data = transformed_batch_data.numpy()
                    transformed_batch_data = np.reshape(
                        transformed_batch_data,
                        (transformed_batch_data.shape[0], -1))
                    labels = np.array(labels)[:, np.newaxis]
                    if transformed_data is None:
                        transformed_data = transformed_batch_data
                        transformed_labels = labels
                    else:
                        transformed_data = np.vstack(
                            (transformed_data, transformed_batch_data))
                        # print(labels.shape)
                        # print(transformed_labels.shape)
                        transformed_labels = np.vstack(
                            (transformed_labels, labels))

                xmean_model = xmeans(data=transformed_data,
                                     kmax=num_classes *
                                     max_real_ratio_number_of_labels)
                xmean_model.process()
                # Extract clustering results: clusters and their centers
                clusters = xmean_model.get_clusters()
                centers = xmean_model.get_centers()
                new_labels = []
                xs = transformed_data
                centers = np.array(centers)
                print("number of clusters: ", len(centers))
                squared_norm_dist = np.sum((xs - centers[:, np.newaxis])**2,
                                           axis=2)
                data_centers = np.argmin(squared_norm_dist, axis=0)
                data_centers = np.expand_dims(data_centers, axis=1)

                new_labels = []
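                # composite pseudo-label = cluster index + original label * num_clusters,
                # giving each (class, cluster) pair its own count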
                for i in range(len(transformed_labels)):
                    new_labels.append(data_centers[i][0] +
                                      transformed_labels[i][0] * len(centers))

                new_label_counts = {}
                for label in new_labels:
                    if label in new_label_counts.keys():
                        new_label_counts[label] += 1
                    else:
                        new_label_counts[label] = 1

                # print(new_label_counts)

                per_cls_weights = []
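                # per original class: effective-number weights over its clusters,
                # averaged with the cluster sizes as weights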
                for i in range(len(cls_num_list)):
                    temp = []
                    for j in range(len(centers)):
                        new_label = j + i * len(centers)
                        if new_label in new_label_counts:
                            temp.append(new_label_counts[new_label])
                    effective_num_temp = 1.0 - np.power(betas[idx], temp)
                    per_cls_weights_temp = (
                        1.0 - betas[idx]) / np.array(effective_num_temp)
                    per_cls_weights.append(
                        np.average(per_cls_weights_temp, weights=temp))
                per_cls_weights = per_cls_weights / np.sum(
                    per_cls_weights) * len(cls_num_list)
                per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(
                    args.gpu)
            elif epoch < cutoff_epoch or (epoch - cutoff_epoch) % 20 == 10:
                effective_num = 1.0 - np.power(betas[idx], cls_num_list)
                per_cls_weights = (1.0 - betas[idx]) / np.array(effective_num)
                per_cls_weights = per_cls_weights / np.sum(
                    per_cls_weights) * len(cls_num_list)
                per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(
                    args.gpu)

        elif args.train_rule == 'DRW_AUTO_CLUSTER':
            train_sampler = None
            cutoff_epoch = args.cutoff_epoch
            idx = min(epoch // cutoff_epoch, 1)
            betas = [0, 0.9999]
            if epoch >= cutoff_epoch and (epoch - cutoff_epoch) % 10 == 0:
                max_real_ratio_number_of_labels = 20
                # todo: transform data batch by batch, then concatentate...
                temp_batch_size = int(train_dataset.data.shape[0])
                temp_train_loader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=int(temp_batch_size),
                    shuffle=(train_sampler is None),
                    num_workers=args.workers,
                    pin_memory=True,
                    sampler=train_sampler)

                transformed_data = None
                transformed_labels = None
                for i, (xs, labels) in enumerate(train_loader):
                    transformed_batch_data = model.forward_partial(
                        xs.cuda(), num_layers=6)
                    transformed_batch_data = transformed_batch_data.cpu(
                    ).detach()
                    transformed_batch_data = transformed_batch_data.numpy()
                    transformed_batch_data = np.reshape(
                        transformed_batch_data,
                        (transformed_batch_data.shape[0], -1))
                    labels = np.array(labels)[:, np.newaxis]
                    if transformed_data is None:
                        transformed_data = transformed_batch_data
                        transformed_labels = labels
                    else:
                        transformed_data = np.vstack(
                            (transformed_data, transformed_batch_data))
                        # print(labels.shape)
                        # print(transformed_labels.shape)
                        transformed_labels = np.vstack(
                            (transformed_labels, labels))

                initial_centers = [
                    np.zeros((transformed_data.shape[1], ))
                    for i in range(num_classes)
                ]
                center_counts = [0 for i in range(num_classes)]
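                # accumulate per-class feature sums, then divide by counts below so the
                # class-mean features seed x-means as initial centers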
                for i in range(transformed_data.shape[0]):
                    temp_idx = transformed_labels[i][0]
                    initial_centers[temp_idx] = initial_centers[
                        temp_idx] + transformed_data[i, :]
                    center_counts[temp_idx] = center_counts[temp_idx] + 1

                for i in range(num_classes):
                    initial_centers[i] = initial_centers[i] / center_counts[i]

                xmean_model = xmeans(data=transformed_data, initial_centers=initial_centers, \
                                     kmax=num_classes * max_real_ratio_number_of_labels)
                xmean_model.process()
                # Extract clustering results: clusters and their centers
                clusters = xmean_model.get_clusters()
                centers = xmean_model.get_centers()
                new_labels = []
                xs = transformed_data
                centers = np.array(centers)
                print("number of clusters: ", len(centers))
                squared_norm_dist = np.sum((xs - centers[:, np.newaxis])**2,
                                           axis=2)
                data_centers = np.argmin(squared_norm_dist, axis=0)
                data_centers = np.expand_dims(data_centers, axis=1)

                new_labels = []
                for i in range(len(transformed_labels)):
                    new_labels.append(data_centers[i][0] +
                                      transformed_labels[i][0] * len(centers))

                new_label_counts = {}
                for label in new_labels:
                    if label in new_label_counts.keys():
                        new_label_counts[label] += 1
                    else:
                        new_label_counts[label] = 1

                # print(new_label_counts)

                per_cls_weights = []
                for i in range(len(cls_num_list)):
                    temp = []
                    for j in range(len(centers)):
                        new_label = j + i * len(centers)
                        if new_label in new_label_counts:
                            temp.append(new_label_counts[new_label])
                    effective_num_temp = 1.0 - np.power(betas[idx], temp)
                    per_cls_weights_temp = (
                        1.0 - betas[idx]) / np.array(effective_num_temp)
                    per_cls_weights.append(
                        np.average(per_cls_weights_temp, weights=temp))
                per_cls_weights = per_cls_weights / np.sum(
                    per_cls_weights) * len(cls_num_list)
                per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(
                    args.gpu)
            elif epoch < cutoff_epoch:
                effective_num = 1.0 - np.power(betas[idx], cls_num_list)
                per_cls_weights = (1.0 - betas[idx]) / np.array(effective_num)
                per_cls_weights = per_cls_weights / np.sum(
                    per_cls_weights) * len(cls_num_list)
                per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(
                    args.gpu)

        else:
            warnings.warn('Sample rule is not listed')

        if args.loss_type == 'CE':
            criterion = nn.CrossEntropyLoss(weight=per_cls_weights).cuda(
                args.gpu)
        elif args.loss_type == 'LDAM':
            criterion = LDAMLoss(cls_num_list=cls_num_list,
                                 max_m=0.5,
                                 s=30,
                                 weight=per_cls_weights).cuda(args.gpu)
            #temp = [cls_num_list[i] * per_cls_weights[i].item() for i in range(len(cls_num_list))]
            #criterion = LDAMLoss(cls_num_list=temp, max_m=0.5, s=30, weight=per_cls_weights).cuda(args.gpu)
        elif args.loss_type == 'Focal':
            criterion = FocalLoss(weight=per_cls_weights,
                                  gamma=1).cuda(args.gpu)
        else:
            warnings.warn('Loss type is not listed')
            return

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args,
              log_training, tf_writer)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, epoch, args, log_testing,
                        tf_writer)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        tf_writer.add_scalar('acc/test_top1_best', best_acc1, epoch)
        output_best = 'Best Prec@1: %.3f\n' % (best_acc1)
        print(output_best)
        log_testing.write(output_best + '\n')
        log_testing.flush()

        save_checkpoint(
            args, {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
Example 4
    # Define Loaders
    train_loader = sample_loader(train_data, train_mask, train_label, model_cfg.batch_size, isMasking=True, mask_ratio=mask_ratio, ZeroImpute=True)
    valid_loader = sample_loader(valid_data, valid_mask, valid_label, model_cfg.batch_size, isMasking=False, mask_ratio=mask_ratio, ZeroImpute=True)
    test_loader = sample_loader(test_data, test_mask, test_label, model_cfg.batch_size, isMasking=False, mask_ratio=mask_ratio, ZeroImpute=True)
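    # only the training loader applies masking (isMasking=True); validation and test data are left intact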

    # Define Model & Optimizer
    if dataset == 'mimic':
        model = BiVRNN(x_dim, model_cfg.h_dim, model_cfg.z1_dim, n_layers, out_ch=model_cfg.out_ch, dropout_p=model_cfg.dropout_p, isdecaying=isdecaying,
                        FFA=FFA, isreparam=isreparam, issampling=False, device=device).to(device)
    else:
        model = BiVRNN(x_dim, model_cfg.h_dim, model_cfg.z_dim, n_layers, out_ch=model_cfg.out_ch, dropout_p=model_cfg.dropout_p, isdecaying=isdecaying,
                        FFA=FFA, isreparam=isreparam, issampling=False, device=device).to(device)

    vrnn_loss = VRNNLoss(lambda1, device, isreconmsk=isreconmsk)

    classification_loss = FocalLoss(lambda1, device, alpha, gamma, logits=False)

    # Reset Best AUC
    bestValidAUC = 0
    best_epoch = 0

    optimizer = RAdam(list(model.parameters()), lr=model_cfg.org_learning_rate, weight_decay=model_cfg.w_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=model_cfg.learning_rate_decay, gamma=0.5)


    for epoch in range(1, model_cfg.n_epochs+1):
        writelog(f, '------ Epoch ' + str(epoch))

        writelog(f, 'Training')
        train(epoch, train_loader)
def test(gpu, args):
    print("Starting...")
    print("Using {} percent of data for testing".format(100 -
                                                        args.training_frac *
                                                        100))
    torch.autograd.set_detect_anomaly(True)
    #torch.manual_seed(args.torch_seed)
    #torch.cuda.manual_seed(args.cuda_seed)
    torch.cuda.set_device(gpu)

    DATA_ROOT = args.data_root
    NUM_IMAGES = args.num_test_images
    CHKPNT_PTH = args.checkpoint_path

    if args.model_name == 'dss':
        model = DSS_Net(args, n_channels=3, n_classes=1, bilinear=True)
        loss = FocalLoss()
    elif args.model_name == 'unet':
        model = UNet(args)
        loss = nn.BCELoss()
    else:
        raise NotImplementedError

    state_dict = torch.load(CHKPNT_PTH)
    new_state_dict = copy_state_dict(state_dict)  #OrderedDict()
    # for k, v in state_dict.items():
    #     name = k[7:]
    #     names = name.strip().split('.')
    #     if names[1] == 'inc':
    #         names[1] = 'conv1'
    #     name = '.'.join(names)
    #     # print(names)
    #     new_state_dict[name] = v

    # print("Expected values:", model.state_dict().keys())
    model.load_state_dict(new_state_dict)
    model.cuda(gpu)
    model.eval()

    if args.test_loader == 'annotated':
        dataset = SmokeDataset(dataset_limit=NUM_IMAGES,
                               training_frac=args.training_frac)
        # dataset = SyntheticSmokeTrain(args,dataset_limit=50)
    else:
        dataset = SimpleSmokeVal(args=args,
                                 data_root=DATA_ROOT,
                                 dataset_limit=NUM_IMAGES)
    dataloader = DataLoader(dataset,
                            1,
                            shuffle=True,
                            num_workers=4,
                            pin_memory=True)  #, sampler=train_sampler)

    # if not args.no_logging and gpu == 1:
    if not os.path.isdir(args.log_dir):
        os.mkdir(args.log_dir)
    log_dir = os.path.join(args.log_dir, args.exp_dir)
    if not os.path.isdir(log_dir):
        os.mkdir(log_dir)
    if args.exp_name == "":
        exp_name = datetime.datetime.now().strftime("%H%M%S-%Y%m%d")
    else:
        exp_name = args.exp_name
        log_dir = os.path.join(log_dir, exp_name)
    writer = SummaryWriter(log_dir)

    iou_sum = 0
    iou_count = 0
    iou_ = 0
    for idx, data in enumerate(dataloader):
        if args.test_loader == 'annotated':
            out_img, iou_ = val_step_with_loss(data, model)
            iou_sum += iou_
            iou_count += 1
            writer.add_images('true_mask', data['target_mask'] > 0, idx)
        else:
            out_img = val_step(data, model)
        writer.add_images('input_img', data['input_img'], idx)
        writer.add_images('pred_mask', out_img, idx)
        writer.add_scalar(f'accuracy/test', iou_, idx)
        writer.flush()
        # print("Step: {}/{}: IOU: {}".format(idx,len(dataloader), iou_))
        if idx > len(dataloader):
            break
    if iou_count > 0:
        iou = iou_sum / iou_count
        writer.add_scalar(f'mean_accuracy/test', iou)
        print("Mean IOU: ", iou)
    print("Done")
Example 6
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    print("=> creating model '{}'".format(args.arch))
    num_classes = 100 if args.dataset == 'cifar100' else 10
    use_norm = True if args.loss_type == 'LDAM' else False
    model = models.__dict__[args.arch](num_classes=num_classes,
                                       use_norm=use_norm)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cuda:0')
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    if args.dataset == 'cifar10':
        train_dataset = IMBALANCECIFAR10(root='./data',
                                         imb_type=args.imb_type,
                                         imb_factor=args.imb_factor,
                                         rand_number=args.rand_number,
                                         train=True,
                                         download=True,
                                         transform=transform_train)
        val_dataset = datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform_val)
    elif args.dataset == 'cifar100':
        train_dataset = IMBALANCECIFAR100(root='./data',
                                          imb_type=args.imb_type,
                                          imb_factor=args.imb_factor,
                                          rand_number=args.rand_number,
                                          train=True,
                                          download=True,
                                          transform=transform_train)
        val_dataset = datasets.CIFAR100(root='./data',
                                        train=False,
                                        download=True,
                                        transform=transform_val)
    else:
        warnings.warn('Dataset is not listed')
        return
    cls_num_list = train_dataset.get_cls_num_list()
    print('cls num list:')
    print(cls_num_list)
    args.cls_num_list = cls_num_list

    train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # init log for training
    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log_train.csv'), 'w')
    log_testing = open(
        os.path.join(args.root_log, args.store_name, 'log_test.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        if args.train_rule == 'None':
            train_sampler = None
            per_cls_weights = None
        elif args.train_rule == 'Resample':
            train_sampler = ImbalancedDatasetSampler(train_dataset)
            per_cls_weights = None
        elif args.train_rule == 'Reweight':
            train_sampler = None
            beta = 0.9999
            effective_num = 1.0 - np.power(beta, cls_num_list)
            per_cls_weights = (1.0 - beta) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
        elif args.train_rule == 'DRW':
            train_sampler = None
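            # DRW schedule: unweighted (beta=0) before epoch 160, class-balanced weights afterwards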
            idx = min(epoch // 160, 1)
            betas = [0, 0.9999]
            effective_num = 1.0 - np.power(betas[idx], cls_num_list)
            per_cls_weights = (1.0 - betas[idx]) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
        else:
            warnings.warn('Sample rule is not listed')

        if args.loss_type == 'CE':
            criterion = nn.CrossEntropyLoss(weight=per_cls_weights).cuda(
                args.gpu)
        elif args.loss_type == 'LDAM':
            criterion = LDAMLoss(cls_num_list=cls_num_list,
                                 max_m=0.5,
                                 s=30,
                                 weight=per_cls_weights).cuda(args.gpu)
        elif args.loss_type == 'Focal':
            criterion = FocalLoss(weight=per_cls_weights,
                                  gamma=1).cuda(args.gpu)
        else:
            warnings.warn('Loss type is not listed')
            return

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args,
              log_training, tf_writer)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, epoch, args, log_testing,
                        tf_writer)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        tf_writer.add_scalar('acc/test_top1_best', best_acc1, epoch)
        output_best = 'Best Prec@1: %.3f\n' % (best_acc1)
        print(output_best)
        log_testing.write(output_best + '\n')
        log_testing.flush()

        save_checkpoint(
            args, {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
Example 7
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    # print("=> creating model '{}'".format(args.arch))
    # num_classes = 100 if args.dataset == 'cifar100' else 10
    # use_norm = True if args.loss_type in ['LDAM'] else False
    # model = models.__dict__[args.arch](num_classes=num_classes, use_norm=use_norm)

    # TODO: create ResNet-50 model
    model = torchvision.models.resnet50(pretrained=False)
    num_classes = 200  # assumed: Tiny ImageNet has 200 classes; required by the class-count-based losses below
    model.fc = nn.Linear(model.fc.in_features, num_classes)  # resize the classifier head accordingly

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cuda:0')
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code

    transform_train = transforms.Compose([
        # transforms.RandomCrop(32, padding=4),
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomRotation(90),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    train_dataset = ImTinyImagenet(root='data/TinyImageNet/train',
                                   imb_type=args.imb_type,
                                   imb_factor=args.imb_factor,
                                   rand_number=args.rand_number,
                                   train=True,
                                   transform=transform_train)
    val_dataset = datasets.ImageFolder(root='data/TinyImageNet/val',
                                       transform=transform_val)
    # if args.dataset == 'cifar10':
    #     train_dataset = IMBALANCECIFAR10(root='./data/CIFAR10', imb_type=args.imb_type, imb_factor=args.imb_factor, rand_number=args.rand_number, train=True, download=True, transform=transform_train)
    #     val_dataset = datasets.CIFAR10(root='./data/CIFAR10', train=False, download=True, transform=transform_val)
    # elif args.dataset == 'cifar100':
    #     train_dataset = IMBALANCECIFAR100(root='./data/CIFAR100', imb_type=args.imb_type, imb_factor=args.imb_factor, rand_number=args.rand_number, train=True, download=True, transform=transform_train)
    #     val_dataset = datasets.CIFAR100(root='./data/CIFAR100', train=False, download=True, transform=transform_val)
    # else:
    #     warnings.warn('Dataset is not listed')
    #     return
    cls_num_list = train_dataset.get_cls_num_list()
    print('cls num list:')
    print(cls_num_list)
    args.cls_num_list = cls_num_list

    train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # init log for training
    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log_train.csv'), 'w')
    log_testing = open(
        os.path.join(args.root_log, args.store_name, 'log_test.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    # TAG Init train rule
    if args.train_rule == 'None':
        train_sampler = None
        per_cls_weights = None
    elif args.train_rule == 'EffectiveNumber':
        train_sampler = None
        beta = 0.9999
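        # class-balanced "effective number" weights: w_c = (1 - beta) / (1 - beta^{n_c}), normalized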
        effective_num = 1.0 - np.power(beta, cls_num_list)
        per_cls_weights = (1.0 - beta) / np.array(effective_num)
        per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
            cls_num_list)
        per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
    elif args.train_rule == 'ClassBlance':
        train_sampler = None
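        # inverse class-frequency weights, normalized to mean 1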
        per_cls_weights = 1.0 / np.array(cls_num_list)
        per_cls_weights = per_cls_weights / np.mean(per_cls_weights)
        per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
    elif args.train_rule == 'ClassBlanceV2':
        train_sampler = None
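        # softer variant: inverse fourth root of class frequency, normalized to mean 1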
        per_cls_weights = 1.0 / np.power(np.array(cls_num_list), 0.25)
        per_cls_weights = per_cls_weights / np.mean(per_cls_weights)
        per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
    else:
        warnings.warn('Sample rule is not listed')

    # TAG Init loss
    if args.loss_type == 'CE':
        # criterion = nn.CrossEntropyLoss(weight=per_cls_weights).cuda(args.gpu)
        criterion = CEloss(weight=per_cls_weights).cuda(args.gpu)
    elif args.loss_type == 'LDAM':
        criterion = LDAMLoss(cls_num_list=cls_num_list,
                             max_m=0.5,
                             s=30,
                             weight=per_cls_weights).cuda(args.gpu)
    elif args.loss_type == 'Focal':
        criterion = FocalLoss(weight=per_cls_weights, gamma=1).cuda(args.gpu)
    elif args.loss_type == 'Seesaw':
        criterion = SeesawLoss(num_classes=num_classes).cuda(args.gpu)
    elif args.loss_type == 'GradSeesawLoss':
        criterion = GradSeesawLoss(num_classes=num_classes).cuda(args.gpu)
    elif args.loss_type == 'SoftSeesaw':
        criterion = SoftSeesawLoss(num_classes=num_classes,
                                   beta=args.beta).cuda(args.gpu)
    elif args.loss_type == 'SoftGradeSeesawLoss':
        criterion = SoftGradeSeesawLoss(num_classes=num_classes).cuda(args.gpu)
    elif args.loss_type == 'Seesaw_prior':
        criterion = SeesawLoss_prior(cls_num_list=cls_num_list).cuda(args.gpu)
    elif args.loss_type == 'GradSeesawLoss_prior':
        criterion = GradSeesawLoss_prior(cls_num_list=cls_num_list).cuda(
            args.gpu)
    elif args.loss_type == 'GHMc':
        criterion = GHMcLoss(bins=30, momentum=0.75,
                             use_sigmoid=True).cuda(args.gpu)
    elif args.loss_type == 'SoftmaxGHMc':
        criterion = SoftmaxGHMc(bins=30, momentum=0.75).cuda(args.gpu)
    elif args.loss_type == 'SoftmaxGHMcV2':
        criterion = SoftmaxGHMcV2(bins=30, momentum=0.75).cuda(args.gpu)
    elif args.loss_type == 'SoftmaxGHMcV3':
        criterion = SoftmaxGHMcV3(bins=30, momentum=0.75).cuda(args.gpu)
    elif args.loss_type == 'SeesawGHMc':
        criterion = SeesawGHMc(bins=30, momentum=0.75).cuda(args.gpu)
    elif args.loss_type == 'EQLv2':
        criterion = EQLv2(num_classes=num_classes).cuda(args.gpu)
    elif args.loss_type == 'EQL':
        criterion = EQLloss(cls_num_list=cls_num_list).cuda(args.gpu)
    elif args.loss_type == 'GHMSeesawV2':
        criterion = GHMSeesawV2(num_classes=num_classes).cuda(args.gpu)
    else:
        warnings.warn('Loss type is not listed')
        return

    valid_criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args,
              log_training, tf_writer)
        # print(criterion.cls_num_list.transpose(1,0))

        # evaluate on validation set
        acc1 = validate(val_loader, model, valid_criterion, epoch, args,
                        log_testing, tf_writer)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        tf_writer.add_scalar('acc/test_top1_best', best_acc1, epoch)
        output_best = 'Best Prec@1: %.3f\n' % (best_acc1)
        print(output_best)
        log_testing.write(output_best + '\n')
        log_testing.flush()

        save_checkpoint(
            args, {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
def main_worker(gpu, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print(f"Use GPU: {args.gpu} for training")

    print(f"===> Creating model '{args.arch}'")
    if args.dataset == 'cifar100':
        num_classes = 100
    elif args.dataset in {'cifar10', 'svhn'}:
        num_classes = 10
    else:
        raise NotImplementedError
    use_norm = True if args.loss_type == 'LDAM' else False
    model = models.__dict__[args.arch](num_classes=num_classes, use_norm=use_norm)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)

    mean = [0.4914, 0.4822, 0.4465] if args.dataset.startswith('cifar') else [.5, .5, .5]
    std = [0.2023, 0.1994, 0.2010] if args.dataset.startswith('cifar') else [.5, .5, .5]
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    if args.dataset == 'cifar10':
        train_dataset = ImbalanceCIFAR10(
            root=args.data_path, imb_type=args.imb_type, imb_factor=args.imb_factor,
            rand_number=args.rand_number, train=True, download=True, transform=transform_train)
        val_dataset = datasets.CIFAR10(root=args.data_path,
                                       train=False, download=True, transform=transform_val)
        train_sampler = None
        if args.train_rule == 'Resample':
            train_sampler = ImbalancedDatasetSampler(train_dataset)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=train_sampler)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=100, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True)
    elif args.dataset == 'cifar100':
        train_dataset = ImbalanceCIFAR100(
            root=args.data_path, imb_type=args.imb_type, imb_factor=args.imb_factor,
            rand_number=args.rand_number, train=True, download=True, transform=transform_train)
        val_dataset = datasets.CIFAR100(root=args.data_path,
                                        train=False, download=True, transform=transform_val)
        train_sampler = None
        if args.train_rule == 'Resample':
            train_sampler = ImbalancedDatasetSampler(train_dataset)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=train_sampler)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=100, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True)
    elif args.dataset == 'svhn':
        train_dataset = ImbalanceSVHN(
            root=args.data_path, imb_type=args.imb_type, imb_factor=args.imb_factor,
            rand_number=args.rand_number, split='train', download=True, transform=transform_train)
        val_dataset = datasets.SVHN(root=args.data_path,
                                    split='test', download=True, transform=transform_val)
        train_sampler = None
        if args.train_rule == 'Resample':
            train_sampler = ImbalancedDatasetSampler(train_dataset)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=train_sampler)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=100, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True)
    else:
        raise NotImplementedError(f"Dataset {args.dataset} is not supported!")

    # evaluate only
    if args.evaluate:
        assert args.resume, 'Specify a trained model using [args.resume]'
        checkpoint = torch.load(args.resume, map_location=torch.device(f'cuda:{str(args.gpu)}'))
        model.load_state_dict(checkpoint['state_dict'])
        print(f"===> Checkpoint '{args.resume}' loaded, testing...")
        validate(val_loader, model, nn.CrossEntropyLoss(), 0, args)
        return

    if args.resume:
        if os.path.isfile(args.resume):
            print(f"===> Loading checkpoint '{args.resume}'")
            checkpoint = torch.load(args.resume, map_location=torch.device(f'cuda:{str(args.gpu)}'))
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(f"===> Loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})")
        else:
            raise ValueError(f"No checkpoint found at '{args.resume}'")

    # load self-supervised pre-trained model
    if args.pretrained_model:
        checkpoint = torch.load(args.pretrained_model, map_location=torch.device(f'cuda:{str(args.gpu)}'))
        if 'moco_ckpt' not in args.pretrained_model:
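            # plain (non-MoCo) checkpoint: copy every weight except the classifier head ('linear'/'fc')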
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict'].items():
                if 'linear' not in k and 'fc' not in k:
                    new_state_dict[k] = v
            model.load_state_dict(new_state_dict, strict=False)
            print(f'===> Pretrained weights found in total: [{len(list(new_state_dict.keys()))}]')
        else:
            # rename moco pre-trained keys
            state_dict = checkpoint['state_dict']
            for k in list(state_dict.keys()):
                # retain only encoder_q up to before the embedding layer
                if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                    # remove prefix
                    state_dict[k[len("module.encoder_q."):]] = state_dict[k]
                # delete renamed or unused k
                del state_dict[k]
            msg = model.load_state_dict(state_dict, strict=False)
            if use_norm:
                assert set(msg.missing_keys) == {"fc.weight"}
            else:
                assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        print(f'===> Pre-trained model loaded: {args.pretrained_model}')

    cudnn.benchmark = True

    if args.dataset.startswith(('cifar', 'svhn')):
        cls_num_list = train_dataset.get_cls_num_list()
        print('cls num list:')
        print(cls_num_list)
        args.cls_num_list = cls_num_list

    # init log for training
    log_training = open(os.path.join(args.root_log, args.store_name, 'log_train.csv'), 'w')
    log_testing = open(os.path.join(args.root_log, args.store_name, 'log_test.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.store_name))

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        if args.train_rule == 'Reweight':
            beta = 0.9999
            effective_num = 1.0 - np.power(beta, cls_num_list)
            per_cls_weights = (1.0 - beta) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
        elif args.train_rule == 'DRW':
            idx = min(epoch // 160, 1)
            betas = [0, 0.9999]
            effective_num = 1.0 - np.power(betas[idx], cls_num_list)
            per_cls_weights = (1.0 - betas[idx]) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
        else:
            per_cls_weights = None

        if args.loss_type == 'CE':
            criterion = nn.CrossEntropyLoss(weight=per_cls_weights).cuda(args.gpu)
        elif args.loss_type == 'LDAM':
            criterion = LDAMLoss(cls_num_list=cls_num_list, max_m=0.5, s=30, weight=per_cls_weights).cuda(args.gpu)
        elif args.loss_type == 'Focal':
            criterion = FocalLoss(weight=per_cls_weights, gamma=1).cuda(args.gpu)
        else:
            warnings.warn('Loss type is not listed')
            return

        train(train_loader, model, criterion, optimizer, epoch, args, log_training, tf_writer)
        acc1 = validate(val_loader, model, criterion, epoch, args, log_testing, tf_writer)

        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        tf_writer.add_scalar('acc/test_top1_best', best_acc1, epoch)
        output_best = 'Best Prec@1: %.3f\n' % best_acc1
        print(output_best)
        log_testing.write(output_best + '\n')
        log_testing.flush()

        save_checkpoint(args, {
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
        }, is_best)
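
Both the 'Reweight' and 'DRW' branches above weight classes by the effective number of samples from the class-balanced loss formulation: E_n = (1 - beta^n) / (1 - beta), each class is weighted by 1 / E_n, and the weights are renormalized to sum to the number of classes. A minimal standalone sketch of that computation (the helper name effective_number_weights is illustrative, not part of the script above):

import numpy as np
import torch

def effective_number_weights(cls_num_list, beta=0.9999):
    # Effective number of samples per class: E_n = (1 - beta**n) / (1 - beta)
    effective_num = 1.0 - np.power(beta, cls_num_list)
    per_cls_weights = (1.0 - beta) / effective_num
    # Renormalize so the weights sum to the number of classes
    per_cls_weights = per_cls_weights / per_cls_weights.sum() * len(cls_num_list)
    return torch.FloatTensor(per_cls_weights)

# Rare classes receive the largest weights, e.g. effective_number_weights([5000, 500, 50])
# puts most of the total weight on the 50-sample class.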
Example 9

ssl_models = [
    "resnet18_ssl",
    "resnet50_ssl",
    "resnext50_32x4d_ssl",
    "resnext101_32x4d_ssl",
    "resnext101_32x8d_ssl",
    "resnext101_32x16d_ssl",
]

eff_models = ["tf_efficientnet_b3_ns", "tf_efficientnet_b4_ns"]

loss_fn = {
    "cross_entropy": F.cross_entropy,
    "focal_loss": FocalLoss(),
    "label_smoothing": LabelSmoothingCrossEntropy(smoothing=0.3),
}
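
# The mapping above is a name -> criterion lookup; "focal_loss" and "label_smoothing"
# are pre-instantiated modules while "cross_entropy" is the functional form. A hedged
# usage sketch with dummy tensors (not part of the original snippet):
import torch

logits = torch.randn(4, 5)              # (batch, num_classes) dummy scores
targets = torch.randint(0, 5, (4,))     # dummy integer class labels
criterion = loss_fn["cross_entropy"]    # or "focal_loss" / "label_smoothing"
print(criterion(logits, targets))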


class CassavaModel(pl.LightningModule):
    def __init__(
        self,
        model_name: str = None,
        num_classes: int = None,
        loss_fn=F.cross_entropy,
        lr=1e-4,
        wd=1e-6,
    ):
        super().__init__()
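
    # Hedged sketch, not from the original excerpt (the class is truncated here):
    # a typical training_step / configure_optimizers for a module like this, assuming
    # __init__ also stored self.model, self.loss_fn, self.lr and self.wd.
    def training_step(self, batch, batch_idx):
        images, targets = batch
        logits = self.model(images)
        loss = self.loss_fn(logits, targets)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.wd)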
Example 10
    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                labels=None,
                position_ids=None,
                head_mask=None,
                h_ids=None,
                h_attention_mask=None,
                p_ids=None,
                p_attention_mask=None,
                have_overlap=None,
                overlap_rate=None,
                subsequence=None,
                constituent=None,
                binary_labels=None):

        if self.hypothesis_only:
            outputs = self.bert(h_ids,
                                token_type_ids=None,
                                attention_mask=h_attention_mask)
            pooled_h = outputs[1]
            pooled_h_g = self.dropout(pooled_h)
            logits = self.h_classifier1(pooled_h_g)
            outputs = (logits, ) + outputs[2:]
        elif not self.hans_only:
            outputs = self.bert(input_ids, position_ids=position_ids, \
                                token_type_ids=token_type_ids, \
                                attention_mask=attention_mask, head_mask=head_mask)
            pooled_output = outputs[1]
            pooled_output = self.dropout(pooled_output)
            logits = self.classifier(pooled_output)
            # add hidden states and attention if they are here
            outputs = (logits, ) + outputs[2:]

        if self.hans:  # if both are correct.
            h_outputs = self.bert(h_ids,
                                  token_type_ids=None,
                                  attention_mask=h_attention_mask)

            if self.ensemble_training:  # also computes the h-only results.
                pooled_h_second = h_outputs[1]
                h_embd_second = grad_mul_const(pooled_h_second, 0.0)
                pooled_h_g_second = self.dropout(h_embd_second)
                h_logits_second = self.h_classifier1_second(pooled_h_g_second)
                h_outputs_second = (h_logits_second, ) + h_outputs[2:]

            h_matrix = h_outputs[0]
            h_matrix = grad_mul_const(h_matrix, 0.0)
            h_matrix = self.dropout(h_matrix)

            p_outputs = self.bert(p_ids,
                                  token_type_ids=None,
                                  attention_mask=p_attention_mask)
            p_matrix = p_outputs[0]
            p_matrix = grad_mul_const(p_matrix, 0.0)
            p_matrix = self.dropout(p_matrix)

            # compute similarity features.
            if self.hans_features:
                similarity_score = get_word_similarity_new(h_matrix, p_matrix, self.similarity,
                                                           h_attention_mask, p_attention_mask)

            # this is the default case.
            hans_h_inputs = torch.cat((similarity_score,
                                       have_overlap.view(-1, 1), overlap_rate.view(-1, 1), subsequence.view(-1, 1),
                                       constituent.view(-1, 1)), 1)

            if self.hans_features and len(self.length_features) != 0:
                length_features = get_length_features(p_attention_mask,
                                                      h_attention_mask,
                                                      self.length_features)
                hans_h_inputs = torch.cat((hans_h_inputs, length_features), 1)

            h_logits = self.h_classifier1(hans_h_inputs)
            h_outputs = (h_logits, ) + h_outputs[2:]

            if self.hans_only:
                logits = h_logits
                # overwrite outputs.
                outputs = h_outputs

        elif self.focal_loss or self.poe_loss or self.rubi:
            h_outputs = self.bert(h_ids,
                                  token_type_ids=None,
                                  attention_mask=h_attention_mask)
            pooled_h = h_outputs[1]
            h_embd = grad_mul_const(pooled_h, 0.0)
            pooled_h_g = self.dropout(h_embd)
            h_logits = self.h_classifier1(pooled_h_g)
            h_outputs = (h_logits, ) + h_outputs[2:]

        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                if self.focal_loss:
                    loss_fct = FocalLoss(gamma=self.gamma_focal, \
                                         ensemble_training=self.ensemble_training,
                                         aggregate_ensemble=self.aggregate_ensemble)
                elif self.poe_loss:
                    loss_fct = POELoss(
                        ensemble_training=self.ensemble_training,
                        poe_alpha=self.poe_alpha)
                elif self.rubi:
                    loss_fct = RUBILoss(num_labels=self.num_labels)
                elif self.hans_only:
                    if self.weighted_bias_only and self.hans:
                        weights = torch.tensor([0.5, 1.0, 0.5]).cuda()
                        loss_fct = CrossEntropyLoss(weight=weights)
                    else:
                        # fall back so loss_fct is always defined in this branch
                        loss_fct = CrossEntropyLoss()
                else:
                    loss_fct = CrossEntropyLoss()

                if self.rubi or self.focal_loss or self.poe_loss:
                    if self.ensemble_training:
                        model_loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1), \
                                              h_logits.view(-1, self.num_labels),
                                              h_logits_second.view(-1, self.num_labels))
                    else:
                        model_loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1), \
                                              h_logits.view(-1, self.num_labels))

                    if self.weighted_bias_only and self.hans:
                        weights = torch.tensor([0.5, 1.0, 0.5]).cuda()
                        h_loss_fct = CrossEntropyLoss(weight=weights)
                        if self.ensemble_training:
                            h_loss_fct_second = CrossEntropyLoss()
                    else:
                        h_loss_fct = CrossEntropyLoss()

                    h_loss = h_loss_fct(h_logits.view(-1, self.num_labels),
                                        labels.view(-1))
                    if self.ensemble_training:
                        h_loss += h_loss_fct_second(
                            h_logits_second.view(-1, self.num_labels),
                            labels.view(-1))

                    loss = model_loss + self.lambda_h * h_loss
                else:
                    loss = loss_fct(logits.view(-1, self.num_labels),
                                    labels.view(-1))
            outputs = (loss, ) + outputs

        all_outputs = {}
        all_outputs["bert"] = outputs
        if self.rubi or self.focal_loss or self.poe_loss:
            all_outputs["h"] = h_outputs
        if self.ensemble_training:
            all_outputs["h_second"] = h_outputs_second
        return all_outputs  # (loss), logits, (hidden_states), (attentions)
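
The forward pass above repeatedly calls grad_mul_const(x, 0.0) so the hypothesis-only and premise-only heads can read the encoder's activations without sending any gradient back into it. A plausible implementation of such a helper (a hedged sketch, not necessarily the exact one this module imports) is an identity in the forward pass that scales the gradient in the backward pass:

import torch

class GradMulConst(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by `const` in backward."""

    @staticmethod
    def forward(ctx, x, const):
        ctx.const = const
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient flows to `const` itself
        return grad_output * ctx.const, None

def grad_mul_const(x, const):
    return GradMulConst.apply(x, const)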
Example 11
def main():
    global args
    args = parser.parse_args()
    with open(args.config) as f:
        config = yaml.safe_load(f)

    for key in config:
        for k, v in config[key].items():
            setattr(args, k, v)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # Normalize the test set same as training set without augmentation
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    per_cls_weights = None  # ensure it is defined even for the balanced ("regular") setting
    if args.imbalance == "regular":
        train_dataset = torchvision.datasets.CIFAR10(
            root='./data', train=True, download=True, transform=transform_train)
    else:
        train_dataset = IMBALANCECIFAR10(root='../part1-convnet/data',
                                         transform=transform_train)
        cls_num_list = train_dataset.get_cls_num_list()
        if args.reweight:
            per_cls_weights = reweight(cls_num_list, beta=args.beta)
            if torch.cuda.is_available():
                per_cls_weights = per_cls_weights.cuda()
        else:
            per_cls_weights = None

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

    test_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=100, shuffle=False, num_workers=2)

    if args.model == 'TwoLayerNet':
        model = TwoLayerNet(3072, 256, 10)
    elif args.model == 'VanillaCNN':
        model = VanillaCNN()
    elif args.model == 'MyModel':
        model = MyModel()
    elif args.model == 'ResNet-32':
        model = resnet32()
    else:
        raise NotImplementedError
    print(model)
    if torch.cuda.is_available():
        model = model.cuda()

    if args.loss_type == "CE":
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = FocalLoss(weight=per_cls_weights, gamma=1)

    optimizer = torch.optim.SGD(model.parameters(), args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.reg)
    best = 0.0
    best_cm = None
    best_model = None
    for epoch in range(args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        # train loop
        train(epoch, train_loader, model, optimizer, criterion)

        # validation loop
        acc, cm = validate(epoch, test_loader, model, criterion)

        if acc > best:
            best = acc
            best_cm = cm
            best_model = copy.deepcopy(model)

    print('Best Prec@1 Accuracy: {:.4f}'.format(best))
    per_cls_acc = best_cm.diag().detach().numpy().tolist()
    for i, acc_i in enumerate(per_cls_acc):
        print("Accuracy of Class {}: {:.4f}".format(i, acc_i))

    if args.save_best:
        torch.save(best_model.state_dict(), './checkpoints/' + args.model.lower() + '.pth')
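
FocalLoss appears throughout these examples with slightly different signatures (weight, gamma, device, ...). The shared idea is FL(p_t) = -(1 - p_t)^gamma * log(p_t), i.e. a cross-entropy term that is down-weighted for well-classified examples. A generic multi-class sketch with optional per-class weights (not any one of the implementations imported above):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleFocalLoss(nn.Module):
    """Generic multi-class focal loss: FL(p_t) = -(1 - p_t)**gamma * log(p_t)."""

    def __init__(self, gamma=2.0, weight=None):
        super().__init__()
        self.gamma = gamma
        self.weight = weight  # optional per-class weights, shape (num_classes,)

    def forward(self, logits, targets):
        log_probs = F.log_softmax(logits, dim=-1)
        # Weighted per-sample cross entropy, kept unreduced so it can be modulated
        ce = F.nll_loss(log_probs, targets, weight=self.weight, reduction="none")
        # p_t: probability assigned to the true class
        pt = torch.exp(-F.nll_loss(log_probs, targets, reduction="none"))
        return ((1.0 - pt) ** self.gamma * ce).mean()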
Example 12
# sanity check
# if len(XRayTrain_dataset.all_classes) != 15: # 15 is the unique number of diseases in this dataset
#     q('\nnumber of classes not equal to 15 !')

# a,b = train_dataset[0]
# print('\nwe are working with \nImages shape: {} and \nTarget shape: {}'.format( a.shape, b.shape))

# make models directory, where the models and the loss plots will be saved
if not os.path.exists(config.models_dir):
    os.mkdir(config.models_dir)

# define the loss function
if args.loss_func == 'FocalLoss':  # by default
    from losses import FocalLoss
    loss_fn = FocalLoss(device=device, gamma=2.).to(device)
elif args.loss_func == 'BCE':
    loss_fn = nn.BCEWithLogitsLoss().to(device)

# define the learning rate
lr = args.lr

if not args.test:  # training

    # initialize the model if not args.resume
    if not args.resume:
        print('\ntraining from scratch')
        # import pretrained model
        #model = models.resnet50(pretrained=True) # pretrained = False bydefault
        #model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet50', pretrained=False)
        model = torch.hub.load('pytorch/vision:v0.9.0',