def main():
    """Entry point for FixMatch semi-supervised training.

    Parses the command line, configures (optionally distributed) CUDA
    execution, selects per-dataset model hyper-parameters, builds the
    labeled/unlabeled/test data loaders, constructs the model, optimizer,
    scheduler and (optionally) an EMA shadow model, resumes from a
    checkpoint if requested, and finally hands everything to ``train``.

    Side effects: initializes logging, may initialize the NCCL process
    group, creates ``args.out`` and a TensorBoard ``SummaryWriter``, and
    mutates the module-level ``best_acc`` when resuming.
    """
    parser = argparse.ArgumentParser(description='PyTorch FixMatch Training')
    parser.add_argument('--gpu-id', default='0', type=int,
                        help='id(s) for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--num-workers', type=int, default=4,
                        help='number of workers')
    parser.add_argument('--dataset', default='filtered1500', type=str,
                        choices=['cifar10', 'cifar100', 'filtered1500'],
                        help='dataset name')
    parser.add_argument('--num-labeled', type=int, default=2000,
                        help='number of labeled data')
    parser.add_argument("--expand-labels", action="store_true",
                        help="expand labels to fit eval steps")
    parser.add_argument('--arch', default='wideresnet', type=str,
                        choices=['wideresnet', 'resnext'],
                        help='dataset name')
    parser.add_argument('--total-steps', default=2 ** 20, type=int,
                        help='number of total steps to run')
    parser.add_argument('--eval-step', default=1024, type=int,
                        help='number of eval steps to run')
    parser.add_argument('--start-epoch', default=0, type=int,
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--batch-size', default=64, type=int,
                        help='train batchsize')
    parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
                        help='initial learning rate')
    parser.add_argument('--warmup', default=0, type=float,
                        help='warmup epochs (unlabeled data based)')
    parser.add_argument('--wdecay', default=5e-4, type=float,
                        help='weight decay')
    # NOTE(review): ``action='store_true'`` combined with ``default=True``
    # means these two flags can never be turned off from the command line.
    # Kept as-is to preserve the CLI contract; consider
    # ``argparse.BooleanOptionalAction`` if disabling is ever needed.
    parser.add_argument('--nesterov', action='store_true', default=True,
                        help='use nesterov momentum')
    parser.add_argument('--use-ema', action='store_true', default=True,
                        help='use EMA model')
    parser.add_argument('--ema-decay', default=0.999, type=float,
                        help='EMA decay rate')
    parser.add_argument('--mu', default=7, type=int,
                        help='coefficient of unlabeled batch size')
    parser.add_argument('--lambda-u', default=1, type=float,
                        help='coefficient of unlabeled loss')
    parser.add_argument('--T', default=1, type=float,
                        help='pseudo label temperature')
    parser.add_argument('--threshold', default=0.95, type=float,
                        help='pseudo label threshold')
    parser.add_argument('--out', default='result',
                        help='directory to output the result')
    parser.add_argument('--resume', default='', type=str,
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--seed', default=None, type=int,
                        help="random seed")
    parser.add_argument("--amp", action="store_true",
                        help="use 16-bit (mixed) precision through NVIDIA apex AMP")
    parser.add_argument("--opt_level", type=str, default="O1",
                        help="apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--no-progress', action='store_true',
                        help="don't use progress bar")
    args = parser.parse_args()
    global best_acc

    def create_model(args):
        # Build the backbone selected by --arch; the depth/width fields are
        # filled in below from the dataset/arch table before this is called.
        if args.arch == 'wideresnet':
            import models.wideresnet as models
            model = models.build_wideresnet(depth=args.model_depth,
                                            widen_factor=args.model_width,
                                            dropout=0,
                                            num_classes=args.num_classes)
        elif args.arch == 'resnext':
            import models.resnext as models
            model = models.build_resnext(cardinality=args.model_cardinality,
                                         depth=args.model_depth,
                                         width=args.model_width,
                                         num_classes=args.num_classes)
        logger.info("Total params: {:.2f}M".format(
            sum(p.numel() for p in model.parameters()) / 1e6))
        return model

    # --- Device / distributed setup ------------------------------------
    if args.local_rank == -1:
        # Single-process run: use the requested GPU directly.
        device = torch.device('cuda', args.gpu_id)
        args.world_size = 1
        args.n_gpu = torch.cuda.device_count()
    else:
        # torch.distributed launch: one process per GPU, NCCL backend.
        torch.cuda.set_device(args.local_rank)
        device = torch.device('cuda', args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.world_size = torch.distributed.get_world_size()
        args.n_gpu = 1
    args.device = device

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        # Non-zero ranks only log warnings to avoid duplicated output.
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        f"Process rank: {args.local_rank}, "
        f"device: {args.device}, "
        f"n_gpu: {args.n_gpu}, "
        f"distributed training: {bool(args.local_rank != -1)}, "
        f"16-bits training: {args.amp}",
    )
    logger.info(dict(args._get_kwargs()))

    if args.seed is not None:
        set_seed(args)

    # Only rank 0 (or the single process) owns the output dir and writer.
    if args.local_rank in [-1, 0]:
        os.makedirs(args.out, exist_ok=True)
        args.writer = SummaryWriter(args.out)

    # --- Per-dataset / per-arch model hyper-parameters ------------------
    if args.dataset == 'cifar10':
        args.num_classes = 10
        if args.arch == 'wideresnet':
            args.model_depth = 28
            args.model_width = 2
        elif args.arch == 'resnext':
            args.model_cardinality = 4
            args.model_depth = 28
            args.model_width = 4
    elif args.dataset == 'cifar100':
        args.num_classes = 100
        if args.arch == 'wideresnet':
            args.model_depth = 28
            args.model_width = 8
        elif args.arch == 'resnext':
            args.model_cardinality = 8
            args.model_depth = 29
            args.model_width = 64
    elif args.dataset == 'filtered1500':
        args.num_classes = 8
        if args.arch == 'wideresnet':
            args.model_depth = 28
            args.model_width = 2
        elif args.arch == 'resnext':
            args.model_cardinality = 4
            args.model_depth = 28
            args.model_width = 4

    # Barrier pattern: rank 0 downloads/prepares the data while the other
    # ranks wait, then the roles swap so everyone reads the cached copy.
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    labeled_dataset, unlabeled_dataset, test_dataset = DATASET_GETTERS[args.dataset](
        args, './data')
    if args.local_rank == 0:
        torch.distributed.barrier()

    train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler
    labeled_trainloader = DataLoader(
        labeled_dataset,
        sampler=train_sampler(labeled_dataset),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        drop_last=True)
    unlabeled_trainloader = DataLoader(
        unlabeled_dataset,
        sampler=train_sampler(unlabeled_dataset),
        # FixMatch uses mu-times more unlabeled than labeled samples per step.
        batch_size=args.batch_size * args.mu,
        num_workers=args.num_workers,
        drop_last=True)
    test_loader = DataLoader(
        test_dataset,
        sampler=SequentialSampler(test_dataset),
        batch_size=args.batch_size,
        num_workers=args.num_workers)

    # Same barrier pattern around model construction.
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    model = create_model(args)
    if args.local_rank == 0:
        torch.distributed.barrier()
    model.to(args.device)

    # Exclude biases and batch-norm parameters from weight decay.
    no_decay = ['bias', 'bn']
    grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(
            nd in n for nd in no_decay)], 'weight_decay': args.wdecay},
        {'params': [p for n, p in model.named_parameters() if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = optim.SGD(grouped_parameters, lr=args.lr,
                          momentum=0.9, nesterov=args.nesterov)

    # One "epoch" here is one evaluation interval of eval_step optimizer steps.
    args.epochs = math.ceil(args.total_steps / args.eval_step)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, args.warmup, args.total_steps)

    # Fix: initialize ema_model so it is always bound — it is passed to
    # train() and read in the resume path regardless of args.use_ema.
    ema_model = None
    if args.use_ema:
        from models.ema import ModelEMA
        ema_model = ModelEMA(args, model, args.ema_decay)

    args.start_epoch = 0

    if args.resume:
        logger.info("==> Resuming from checkpoint..")
        assert os.path.isfile(
            args.resume), "Error: no checkpoint directory found!"
        # Continue writing results next to the checkpoint being resumed.
        args.out = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        if args.use_ema:
            ema_model.ema.load_state_dict(checkpoint['ema_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

    if args.amp:
        # Apex must wrap model+optimizer after checkpoint restore and
        # before DDP wrapping.
        from apex import amp
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=args.opt_level)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank, find_unused_parameters=True)

    logger.info("***** Running training *****")
    logger.info(f" Task = {args.dataset}@{args.num_labeled}")
    logger.info(f" Num Epochs = {args.epochs}")
    logger.info(f" Batch size per GPU = {args.batch_size}")
    logger.info(
        f" Total train batch size = {args.batch_size * args.world_size}")
    logger.info(f" Total optimization steps = {args.total_steps}")

    model.zero_grad()
    train(args, labeled_trainloader, unlabeled_trainloader, test_loader,
          model, optimizer, ema_model, scheduler)
def main(dataset):
    """Train a semi-supervised ResNet50 classifier on the given dataset name.

    Builds the TCGA data loaders, a CUDA model plus its EMA shadow copy,
    optimizer and warmup-cosine scheduler, optionally resumes from
    ``args.resume``, then alternates ``train_semi`` epochs with pseudo-label
    refreshes, periodic validation and checkpointing.

    Relies on a module-level ``args`` namespace and mutates the module-level
    ``best_prec1``.  ``dataset`` is used only to name the log/checkpoint files.
    """
    def create_model(ema=False):
        # Builds the backbone and moves it to the default CUDA device.
        # NOTE: the ``ema`` flag only affects the printed message here.
        print("=> creating {ema}model ".format(
            ema='EMA ' if ema else ''))
        #model = TCN(input_size=1, output_size=args.n_class, num_channels=[32] *8, kernel_size=2)
        model = ResNet50(args.n_class)
        model.cuda()
        return model

    global best_prec1

    # Data
    print('==> Preparing tcga data')
    # The same (noise-augmented) transform is reused as the "strong" one.
    transform_train = transforms.Compose([
        GaussianNoise(),
        ToTensor(),
    ])
    transform_val = transforms.Compose([
        ToTensor(),
    ])
    # val_set is returned by get_datasets but no validation loader is built
    # from it here — only the test set is evaluated below.
    train_labeled_set, train_unlabeled_set, train_unlabeled_set2, val_set, test_set = get_datasets(
        './data', args.index, args.n_labeled, args.n_class,
        transform_train=transform_train, transform_strong=transform_train,
        transform_val=transform_val, withGeo=args.geo)
    train_labeled_loader = data.DataLoader(
        train_labeled_set, batch_size=args.batch_size,
        num_workers=args.num_workers, shuffle=True, drop_last=True)
    # Unlabeled batches are unsup_ratio times larger than labeled ones.
    train_unlabeled_loader = data.DataLoader(
        train_unlabeled_set, batch_size=args.batch_size * args.unsup_ratio,
        shuffle=True, num_workers=args.num_workers, drop_last=True)
    # Second, unshuffled pass over the unlabeled data, used by get_u_label
    # so pseudo-labels stay aligned with dataset order.
    train_unlabeled_loader2 = data.DataLoader(
        train_unlabeled_set2, batch_size=args.batch_size * args.unsup_ratio,
        shuffle=False, num_workers=args.num_workers)
    test_loader = data.DataLoader(
        test_set, batch_size=args.batch_size,
        shuffle=False, num_workers=args.num_workers)

    model = create_model()
    ema_model = ModelEMA(args, model, args.ema_decay)

    criterion = nn.CrossEntropyLoss().cuda()
    # Exclude biases and batch-norm parameters from weight decay.
    no_decay = ['bias', 'bn']
    grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(
            nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = optim.SGD(grouped_parameters, lr=args.lr,
                          momentum=0.9, nesterov=True)

    # Schedule lengths are expressed in optimizer steps, not epochs.
    totals = args.epochs * args.epoch_iteration
    warmup_step = args.warmup_step * args.epoch_iteration
    scheduler = WarmupCosineSchedule(optimizer, warmup_step, totals)

    # Running pseudo-label table: one row of class scores per unlabeled
    # sample, refreshed each epoch by get_u_label.
    all_labels = np.zeros([len(train_unlabeled_set), args.n_class])

    # optionally resume from a checkpoint
    title = dataset
    if args.resume:
        assert os.path.isfile(args.resume), "=> no checkpoint found at '{}'".format(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        ema_model.ema.load_state_dict(checkpoint['ema_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        # Sanity-check the restored EMA weights before resuming training.
        print("Evaluating the model:")
        test_loss, test_acc = validate(test_loader, ema_model.ema, criterion)
        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        # Reopen the existing log file and append the resume-time scores
        # (train losses are unknown here, logged as 0).
        logger = Logger(os.path.join(args.out_path, '%s_log_%d.txt' % (dataset, args.n_labeled)), title=title, resume=True)
        logger.append([args.start_epoch, 0, 0, test_loss, test_acc])
    else:
        logger = Logger(os.path.join(args.out_path, '%s_log_%d.txt' % (dataset, args.n_labeled)), title=title)
        logger.set_names(['epoch', 'Train_class_loss', 'Train_consistency_loss', 'Test_Loss', 'Test_Acc.'])

    for epoch in range(args.start_epoch, args.epochs):
        start_time = time.time()
        # train for one epoch
        class_loss, cons_loss = train_semi(train_labeled_loader, train_unlabeled_loader, model, ema_model, optimizer, all_labels, epoch, scheduler)
        # Refresh the pseudo-label table from the current model's predictions.
        all_labels = get_u_label(model, train_unlabeled_loader2, all_labels)
        print("--- training epoch in %s seconds ---" % (time.time() - start_time))

        if args.evaluation_epochs and (epoch + 1) % args.evaluation_epochs == 0:
            start_time = time.time()
            # Two log rows per evaluation: raw model first, then EMA model
            # (same epoch and train losses in both rows).
            print("Evaluating the model:")
            test_loss, test_acc = validate(test_loader, model, criterion)
            print("--- validation in %s seconds ---" % (time.time() - start_time))
            logger.append([epoch, class_loss, cons_loss, test_loss, test_acc])
            print("Evaluating the EMA model:")
            ema_test_loss, ema_test_acc = validate(test_loader, ema_model.ema, criterion)
            print("--- validation in %s seconds ---" % (time.time() - start_time))
            logger.append([epoch, class_loss, cons_loss, ema_test_loss, ema_test_acc])

        if args.checkpoint_epochs and (epoch + 1) % args.checkpoint_epochs == 0:
            save_checkpoint(
                '%s_%d' % (dataset, args.n_labeled), {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'ema_state_dict': ema_model.ema.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                }, 'checkpoint_path', epoch + 1)