def main():
    """Train an image CNN (``visual_model``) jointly with a frozen, pre-trained
    text CNN, driven entirely by a JSON config file.

    Pipeline: parse config -> set up save directory + logging -> build the
    visual model -> load the text-CNN checkpoint (required; aborts if missing)
    -> place models on device(s) -> criterion/optimizer/scheduler -> optional
    resume of the visual model -> data loaders -> evaluate-only, or the
    train/validate/checkpoint epoch loop.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Pytorch Image CNN training from Configure Files")
    parser.add_argument('--config_file', required=True, help="This scripts only accepts parameters from Json files")
    input_args = parser.parse_args()
    config_file = input_args.config_file

    # Everything else comes from the JSON config, not the command line.
    args = parse_config(config_file)
    if args.name is None:
        args.name = get_stem(config_file)

    torch.set_default_tensor_type('torch.FloatTensor')

    best_prec1 = 0
    args.script_name = get_stem(__file__)
    current_time_str = get_date_str()

    # Outputs go under ckpts2/<name>/<ID>-<timestamp> unless an explicit
    # save_directory is configured.
    # if args.resume is None:
    if args.save_directory is None:
        save_directory = get_dir(os.path.join(project_root, 'ckpts2', '{:s}'.format(args.name), '{:s}-{:s}'.format(args.ID, current_time_str)))
    else:
        save_directory = get_dir(os.path.join(project_root, 'ckpts2', args.save_directory))
    # else:
    #     save_directory = os.path.dirname(args.resume)

    print("Save to {}".format(save_directory))
    log_file = os.path.join(save_directory, 'log-{0}.txt'.format(current_time_str))
    logger = log_utils.get_logger(log_file)
    log_utils.print_config(vars(args), logger)
    print_func = logger.info
    print_func('ConfigFile: {}'.format(config_file))
    args.log_file = log_file

    # Restrict visible GPUs before any CUDA context is created.
    if args.device:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    args.distributed = args.world_size > 1
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size)

    # Build the image CNN (architecture chosen by name from the project's
    # `models` registry).
    if args.pretrained:
        print_func("=> using pre-trained model '{}'".format(args.arch))
        visual_model = models.__dict__[args.arch](pretrained=True, num_classes=args.num_classes)
    else:
        print_func("=> creating model '{}'".format(args.arch))
        visual_model = models.__dict__[args.arch](pretrained=False, num_classes=args.num_classes)

    if args.freeze:
        visual_model = CNN_utils.freeze_all_except_fc(visual_model)

    # The text CNN checkpoint is mandatory: without it this script aborts.
    if os.path.isfile(args.text_ckpt):
        print_func("=> loading checkpoint '{}'".format(args.text_ckpt))
        # map_location keeps the load CPU-side regardless of where it was saved.
        text_data = torch.load(args.text_ckpt, map_location=lambda storage, loc: storage)
        text_model = TextCNN(text_data['args_model'])
        # load_state_dict(text_model, text_data['state_dict'])
        text_model.load_state_dict(text_data['state_dict'], strict=True)
        # Text model stays frozen in eval mode; only the visual model trains.
        text_model.eval()
        print_func("=> loaded checkpoint '{}' for text classification"
                   .format(args.text_ckpt))
        args.vocab_size = text_data['args_model'].vocab_size
    else:
        print_func("=> no checkpoint found at '{}'".format(args.text_ckpt))
        return

    args.tag2clsidx = text_data['args_data'].tag2idx
    # NOTE(review): this overwrites the vocab_size taken from the checkpoint
    # above — presumably intentional (tag vocabulary), but verify.
    args.vocab_size = len(args.tag2clsidx)
    args.text_embed = loadpickle(args.text_embed)
    args.idx2tag = loadpickle(args.idx2tag)['idx2tag']

    # Device placement.
    if args.gpu is not None:
        visual_model = visual_model.cuda(args.gpu)
        text_model = text_model.cuda((args.gpu))
    elif args.distributed:
        visual_model.cuda()
        visual_model = torch.nn.parallel.DistributedDataParallel(visual_model)
        # NOTE(review): text_model is never moved to GPU on this path — confirm
        # whether the distributed branch is actually exercised.
    else:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            visual_model.features = torch.nn.DataParallel(visual_model.features)
            visual_model.cuda()
        else:
            visual_model = torch.nn.DataParallel(visual_model).cuda()
        text_model = torch.nn.DataParallel(text_model).cuda()

    # ignore_index=-1 lets unlabeled targets be skipped by the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=-1).cuda(args.gpu)

    # Only parameters with requires_grad are optimized (supports args.freeze).
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, visual_model.parameters()),
                                lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    if args.lr_schedule:
        print_func("Using scheduled learning rate")
        # Comma-separated epoch milestones, decay by 10x at each.
        scheduler = lr_scheduler.MultiStepLR(
            optimizer, [int(i) for i in args.lr_schedule.split(',')], gamma=0.1)
    else:
        scheduler = lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', patience=args.lr_patience)
    # optimizer = torch.optim.SGD(model.parameters(), args.lr,
    #                             momentum=args.momentum,
    #                             weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_func("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            import collections
            # A raw OrderedDict is a bare state_dict; otherwise it is a full
            # checkpoint dict with a 'state_dict' entry.
            if isinstance(checkpoint, collections.OrderedDict):
                load_state_dict(visual_model, checkpoint)
            else:
                load_state_dict(visual_model, checkpoint['state_dict'])
                print_func("=> loaded checkpoint '{}' (epoch {})"
                           .format(args.resume, checkpoint['epoch']))
        else:
            print_func("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    model_total_params = sum(p.numel() for p in visual_model.parameters())
    model_grad_params = sum(p.numel() for p in visual_model.parameters() if p.requires_grad)
    print_func("Total Parameters: {0}\t Gradient Parameters: {1}".format(model_total_params, model_grad_params))

    # Data loading code
    val_dataset = get_instance(custom_datasets, '{0}'.format(args.valloader), args)
    if val_dataset is None:
        val_loader = None
    else:
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True, collate_fn=none_collate)

    if args.evaluate:
        print_func('Validation Only')
        validate(val_loader, visual_model, criterion, args, print_func)
        return
    else:
        train_dataset = get_instance(custom_datasets, '{0}'.format(args.trainloader), args)
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        else:
            train_sampler = None
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=none_collate)

        for epoch in range(args.start_epoch, args.epochs):
            if args.distributed:
                train_sampler.set_epoch(epoch)
            if args.lr_schedule:
                # CNN_utils.adjust_learning_rate(optimizer, epoch, args.lr)
                scheduler.step()
            current_lr = optimizer.param_groups[0]['lr']
            print_func("Epoch: [{}], learning rate: {}".format(epoch, current_lr))

            # train for one epoch
            train(train_loader, visual_model, text_model, criterion, optimizer, epoch, args, print_func)

            # evaluate on validation set
            if val_loader:
                prec1, val_loss = validate(val_loader, visual_model, criterion, args, print_func)
            else:
                prec1 = 0
                val_loss = 0

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            CNN_utils.save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': visual_model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, file_directory=save_directory, epoch=epoch)

            # Plateau scheduler steps on validation loss (only when no fixed
            # milestone schedule is configured).
            if not args.lr_schedule:
                scheduler.step(val_loss)
def main():
    """Train one shared image CNN across ``args.num_datasets`` datasets,
    driven by a JSON config file.

    Builds per-dataset val/test/train loaders, then each epoch trains on all
    datasets (via ``train_loads_iters``) and validates/tests/checkpoints per
    dataset, tracking a separate best precision for each.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description="Pytorch Image CNN training from Configure Files")
    parser.add_argument(
        '--config_file',
        required=True,
        help="This scripts only accepts parameters from Json files")
    input_args = parser.parse_args()
    config_file = input_args.config_file

    # Everything else comes from the JSON config, not the command line.
    args = parse_config(config_file)
    if args.name is None:
        args.name = get_stem(config_file)

    torch.set_default_tensor_type('torch.FloatTensor')

    args.script_name = get_stem(__file__)
    current_time_str = get_date_str()

    # Save directory: fresh runs under ckpts/<name>/<ID>-<timestamp>; resumed
    # runs reuse the checkpoint's directory unless save_directory overrides it.
    if args.resume is None:
        if args.save_directory is None:
            save_directory = get_dir(
                os.path.join(project_root, 'ckpts', '{:s}'.format(args.name),
                             '{:s}-{:s}'.format(args.ID, current_time_str)))
        else:
            save_directory = get_dir(
                os.path.join(project_root, 'ckpts', args.save_directory))
    else:
        if args.save_directory is None:
            save_directory = os.path.dirname(args.resume)
        else:
            current_time_str = get_date_str()
            save_directory = get_dir(
                os.path.join(args.save_directory, '{:s}'.format(args.name),
                             '{:s}-{:s}'.format(args.ID, current_time_str)))

    print("Save to {}".format(save_directory))
    log_file = os.path.join(save_directory, 'log-{0}.txt'.format(current_time_str))
    logger = log_utils.get_logger(log_file)
    log_utils.print_config(vars(args), logger)
    print_func = logger.info
    print_func('ConfigFile: {}'.format(config_file))
    args.log_file = log_file

    # Restrict visible GPUs before any CUDA context is created.
    if args.device:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    # Distributed training is hard-disabled in this variant.
    #args.distributed = args.world_size > 1
    args.distributed = False
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size)

    num_datasets = args.num_datasets
    # Leftover from an earlier per-dataset-model design (one shared model now):
    # model_list = [None for x in range(num_datasets)]
    # for j in range(num_datasets):

    # One shared model with args.class_len output classes.
    if args.pretrained:
        print_func("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True, num_classes=args.class_len)
    else:
        print_func("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=False, num_classes=args.class_len)

    if args.freeze:
        model = CNN_utils.freeze_all_except_fc(model)

    # Device placement.
    if args.gpu is not None:
        model = model.cuda(args.gpu)
    elif args.distributed:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # NOTE(review): the entire criterion/params block below is commented out,
    # so `criterion` does not appear to be defined anywhere in this function
    # even though it is passed to train()/validate() later. Unless `criterion`
    # exists at module scope, that is a NameError — confirm which loss was
    # meant to stay live (nn.CrossEntropyLoss vs. MclassCrossEntropyLoss).
    # # define loss function (criterion) and optimizer
    # # # Update: here
    # # config = {'loss': {'type': 'simpleCrossEntropyLoss', 'args': {'param': None}}}
    # # criterion = get_instance(loss_funcs, 'loss', config)
    # # criterion = criterion.cuda(args.gpu)
    # criterion = nn.CrossEntropyLoss(ignore_index=-1).cuda(args.gpu)
    # criterion = MclassCrossEntropyLoss().cuda(args.gpu)
    # params = list()
    # for j in range(num_datasets):
    #     params += list(model_list[j].parameters())

    # Only parameters with requires_grad are optimized (supports args.freeze).
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                                lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    if args.lr_schedule:
        print_func("Using scheduled learning rate")
        # Comma-separated epoch milestones, decay by 10x at each.
        scheduler = lr_scheduler.MultiStepLR(
            optimizer, [int(i) for i in args.lr_schedule.split(',')], gamma=0.1)
    else:
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min',
                                                   patience=args.lr_patience)

    # Resume support is disabled (commented out via this string literal).
    '''
    if args.resume:
        if os.path.isfile(args.resume):
            print_func("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            import collections
            if not args.evaluate:
                if isinstance(checkpoint, collections.OrderedDict):
                    load_state_dict(model, checkpoint, exclude_layers=['fc.weight', 'fc.bias'])
                else:
                    load_state_dict(model, checkpoint['state_dict'], exclude_layers=['module.fc.weight', 'module.fc.bias'])
                print_func("=> loaded checkpoint '{}' (epoch {})"
                           .format(args.resume, checkpoint['epoch']))
            else:
                if isinstance(checkpoint, collections.OrderedDict):
                    load_state_dict(model, checkpoint, strict=True)
                else:
                    load_state_dict(model, checkpoint['state_dict'], strict=True)
                print_func("=> loaded checkpoint '{}' (epoch {})"
                           .format(args.resume, checkpoint['epoch']))
        else:
            print_func("=> no checkpoint found at '{}'".format(args.resume))
            return
    '''

    cudnn.benchmark = True

    model_total_params = sum(p.numel() for p in model.parameters())
    model_grad_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print_func("Total Parameters: {0}\t Gradient Parameters: {1}".format(
        model_total_params, model_grad_params))

    # Data loading code: one val/test/train loader slot per dataset.
    val_loaders = [None for x in range(num_datasets)]
    test_loaders = [None for x in range(num_datasets)]
    train_loaders = [None for x in range(num_datasets)]
    num_iter = 0
    for k in range(num_datasets):
        # args.ind tells get_instance which dataset's files to use.
        args.ind = k
        val_dataset = get_instance(custom_datasets, args.val_loader, args)
        # NOTE(review): the last dataset's val loader is always dropped
        # (k == num_datasets - 1) — confirm this is intentional.
        if val_dataset is None or k == num_datasets - 1:
            val_loaders[args.ind] = None
        else:
            val_loaders[args.ind] = torch.utils.data.DataLoader(
                val_dataset, batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True, collate_fn=none_collate)
        if hasattr(args, 'test_files') and hasattr(args, 'test_loader'):
            test_dataset = get_instance(custom_datasets, args.test_loader, args)
            test_loaders[args.ind] = torch.utils.data.DataLoader(
                test_dataset, batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True, collate_fn=none_collate)
        else:
            # test_dataset = None
            test_loaders[args.ind] = None
        #if args.evaluate:
        #    validate(test_loaders[args.ind], model_list[k], criterion, args, print_func)
        #    return
        # if not args.evaluate:
        #else:
        # train_samplers = [None for x in range(num_datasets)]
        # train_dataset = get_instance(custom_datasets, args.train_loader, args)
        #
        # if args.distributed:
        #     train_samplers[args.ind] = torch.utils.data.distributed.DistributedSampler(train_dataset)
        # else:
        #     train_samplers[args.ind] = None
        #
        # train_loaders[args.ind] = torch.utils.data.DataLoader(
        #     train_dataset, batch_size=args.batch_size, shuffle=(train_samplers[args.ind] is None),
        #     num_workers=args.workers, pin_memory=True, sampler=train_samplers[args.ind], collate_fn=none_collate)
        if not args.evaluate:
            #else:
            # train_samplers = [None for x in range(num_datasets)]
            train_dataset = get_instance(custom_datasets, args.train_loader, args)
            # num_iter = size of the largest training set; exposed on args so
            # train() can pace multi-dataset iteration.
            num_iter = max(num_iter, len(train_dataset.samples))
            if args.distributed:
                train_samplers = torch.utils.data.distributed.DistributedSampler(
                    train_dataset)
            else:
                train_samplers = None
            train_loaders[args.ind] = torch.utils.data.DataLoader(
                train_dataset, batch_size=args.batch_size, shuffle=train_samplers is None,
                num_workers=args.workers, pin_memory=True, sampler=train_samplers,
                collate_fn=none_collate)
    setattr(args, 'num_iter', num_iter)

    # TRAINING
    best_prec1 = [-1 for _ in range(num_datasets)]
    is_best = [None for _ in range(num_datasets)]
    # lam: mixing weight consumed by train() — fixed at 0.5 here.
    setattr(args, 'lam', 0.5)
    start_data_time = time.time()
    # NOTE(review): iterators are created once, before the epoch loop —
    # presumably train() re-creates them from train_loaders on exhaustion.
    train_loads_iters = [iter(train_loaders[x]) for x in range(num_datasets)]
    print_func("Loaded data in {:.3f} s".format(time.time() - start_data_time))
    for epoch in range(args.start_epoch, args.epochs):
        # Dead branch: args.distributed is forced False above. (If re-enabled,
        # note train_samplers is a single sampler here, not a list.)
        if args.distributed:
            for x in range(num_datasets):
                train_samplers[x].set_epoch(epoch)
        if args.lr_schedule:
            # CNN_utils.adjust_learning_rate(optimizer, epoch, args.lr)
            scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        print_func("Epoch: [{}], learning rate: {}".format(epoch, current_lr))

        # train for one epoch across all datasets
        train(train_loads_iters, train_loaders, model, criterion, optimizer, epoch, args, print_func)

        # evaluate and save, per dataset
        val_prec1 = [None for x in range(num_datasets)]
        test_prec1 = [None for x in range(num_datasets)]
        for j in range(num_datasets):
            # if j != args.ind:
            #     load_state_dict(model_list[j], model_list[args.ind].state_dict())
            # evaluate on validation set
            if val_loaders[j]:
                val_prec1[j], _ = validate(val_loaders[j], model, criterion, args, print_func, j)
            else:
                val_prec1[j] = 0

            # remember best prec@1 and save checkpoint
            is_best[j] = val_prec1[j] > best_prec1[j]
            best_prec1[j] = max(val_prec1[j], best_prec1[j])
            # ind tags the checkpoint with the dataset index when it is a new
            # best, "#" otherwise.
            if is_best[j]:
                save_ind = j
            else:
                save_ind = "#"
            CNN_utils.save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1[j],
                    'optimizer': optimizer.state_dict(),
                },
                is_best[j],
                file_directory=save_directory,
                epoch=epoch,
                save_best_only=args.save_best_only,
                ind=save_ind)
            test_prec1[j], _ = validate(test_loaders[j], model, criterion, args, print_func, j, phase='Test')
        print_func("Val precisions: {}".format(val_prec1))
        print_func("Test precisions: {}".format(test_prec1))
def main():
    """Extract CNN features for every image listed in ``args.val_file``.

    Loads a fine-tuned checkpoint (required via ``args.resume``), runs each
    image through the model with gradients disabled, and writes features
    either as one ``.npy`` per image (``args.individual_feat``) or as a single
    pickled dict mapping relative image path -> feature array.

    All options come from the JSON config named by ``--config_file``.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description="Pytorch Image CNN training from Configure Files")
    parser.add_argument(
        '--config_file',
        required=True,
        help="This scripts only accepts parameters from Json files")
    input_args = parser.parse_args()
    config_file = input_args.config_file

    # Everything else comes from the JSON config, not the command line.
    args = parse_config(config_file)
    if args.name is None:
        args.name = get_stem(config_file)

    torch.set_default_tensor_type('torch.FloatTensor')

    args.script_name = get_stem(__file__)
    current_time_str = get_date_str()

    # Outputs go under <ckpts_dir>/<name>/<ID>-<timestamp> unless an explicit
    # save_directory is configured.
    if args.save_directory is None:
        save_directory = get_dir(
            os.path.join(project_root, args.ckpts_dir, '{:s}'.format(args.name),
                         '{:s}-{:s}'.format(args.ID, current_time_str)))
    else:
        save_directory = get_dir(
            os.path.join(project_root, args.ckpts_dir, args.save_directory))

    print("Save to {}".format(save_directory))
    log_file = os.path.join(save_directory, 'log-{0}.txt'.format(current_time_str))
    logger = log_utils.get_logger(log_file)
    log_utils.print_config(vars(args), logger)
    print_func = logger.info
    print_func('ConfigFile: {}'.format(config_file))
    args.log_file = log_file

    # Restrict visible GPUs before any CUDA context is created.
    if args.device:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    args.distributed = args.world_size > 1
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # Build the image CNN.
    if args.pretrained:
        print_func("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True, num_classes=args.num_classes)
    else:
        print_func("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=False, num_classes=args.num_classes)

    if args.freeze:
        model = CNN_utils.freeze_all_except_fc(model)

    # Feature extraction processes one image at a time, so exactly one GPU
    # (or single-process distributed) is supported.
    if args.gpu is not None:
        model = model.cuda(args.gpu)
    elif args.distributed:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        print_func(
            'Please only specify one GPU since we are working in batch size 1 model'
        )
        return

    # Load the fine-tuned weights. When fine-tuning (not evaluating) the fc
    # layers are excluded so a differently-sized head can be attached.
    if args.resume:
        if os.path.isfile(args.resume):
            print_func("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            import collections
            if not args.evaluate:
                # A raw OrderedDict is a bare state_dict; otherwise a full
                # checkpoint dict with a 'state_dict' entry.
                if isinstance(checkpoint, collections.OrderedDict):
                    load_state_dict(model, checkpoint,
                                    exclude_layers=['fc.weight', 'fc.bias'])
                else:
                    load_state_dict(
                        model,
                        checkpoint['state_dict'],
                        exclude_layers=['module.fc.weight', 'module.fc.bias'])
                print_func("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                if isinstance(checkpoint, collections.OrderedDict):
                    load_state_dict(model, checkpoint, strict=True)
                else:
                    load_state_dict(model, checkpoint['state_dict'], strict=True)
                print_func("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            print_func("=> no checkpoint found at '{}'".format(args.resume))
            return
    else:
        print_func(
            "=> This script is for fine-tuning only, please double check '{}'".
            format(args.resume))
        print_func("Now using randomly initialized parameters!")

    cudnn.benchmark = True

    model_total_params = sum(p.numel() for p in model.parameters())
    model_grad_params = sum(p.numel() for p in model.parameters()
                            if p.requires_grad)
    print_func("Total Parameters: {0}\t Gradient Parameters: {1}".format(
        model_total_params, model_grad_params))

    # Data loading code: a pickled list whose entries start with the image's
    # path relative to args.data_dir (s_data[0]).
    from PyUtils.pickle_utils import loadpickle
    from torchvision.datasets.folder import default_loader
    val_dataset = loadpickle(args.val_file)
    image_directory = args.data_dir
    from CNNs.datasets.multilabel import get_val_simple_transform
    val_transform = get_val_simple_transform()
    import tqdm
    import numpy as np

    if args.individual_feat:
        feature_save_directory = get_dir(
            os.path.join(save_directory, 'individual-features'))
        # Directories already created under feature_save_directory; keyed by
        # the target .npy's parent dir so nested s_data[0] paths work.
        created_paths = set()
    else:
        data_dict = {}
        feature_save_directory = os.path.join(save_directory, 'feature.pkl')

    model.eval()
    for s_data in tqdm.tqdm(val_dataset, desc="Extracting Features"):
        if s_data is None:
            continue
        image_path = os.path.join(image_directory, s_data[0])
        try:
            input_image = default_loader(image_path)
        except Exception:
            # Narrowed from a bare except so Ctrl-C still interrupts; any
            # unreadable/corrupt image is skipped best-effort as before.
            print("WARN: {} Problematic!, Skip!".format(image_path))
            continue
        input_image = val_transform(input_image)
        if args.gpu is not None:
            input_image = input_image.cuda(args.gpu, non_blocking=True)
        # Inference only: no_grad avoids building an autograd graph per image
        # (otherwise memory grows over the whole dataset).
        with torch.no_grad():
            output = model(input_image.unsqueeze(0))
        output = output.cpu().numpy()

        if args.individual_feat:
            # s_data[0] may contain subdirectories, so the .npy's parent
            # directory must exist; create each distinct parent once.
            target_path = os.path.join(feature_save_directory,
                                       '{}.npy'.format(s_data[0]))
            target_dir = os.path.dirname(target_path)
            if target_dir not in created_paths:
                get_dir(target_dir)
                created_paths.add(target_dir)
            np.save(target_path, output)
        else:
            data_dict[s_data[0]] = output

    if args.individual_feat:
        print_func("Done")
    else:
        from PyUtils.pickle_utils import save2pickle
        print_func("Saving to a single big file!")
        save2pickle(feature_save_directory, data_dict)
        print_func("Done")
def main():
    """Standard single-model image-CNN training / fine-tuning entry point.

    Options come either from argparse (project-level ``parser`` module) or,
    when ``--config`` is given, from a JSON config file. Supports full resume
    (model + optimizer + epoch + best precision), evaluate-only mode, and a
    train/validate/checkpoint epoch loop.
    """
    best_prec1 = 0
    # NOTE(review): `parser` here is presumably a project module exposing a
    # `parser` attribute (argparse.ArgumentParser) — confirm; it is not the
    # local argparse pattern used by the other entry points in this file.
    args = parser.parser.parse_args()
    config_file = None
    if args.config is not None:
        config_file = args.config
        # Config file, when present, replaces the argparse namespace entirely.
        args = parse_config(args.config)

    script_name_stem = get_stem(__file__)
    current_time_str = get_date_str()

    # Save directory: fresh runs under ckpts/<name>/<script>-<ID>-<timestamp>;
    # resumed runs reuse the checkpoint's directory.
    if args.resume is None:
        if args.save_directory is None:
            save_directory = get_dir(
                os.path.join(
                    project_root, 'ckpts', '{:s}'.format(args.name),
                    '{:s}-{:s}-{:s}'.format(script_name_stem, args.ID,
                                            current_time_str)))
        else:
            save_directory = get_dir(
                os.path.join(project_root, 'ckpts', args.save_directory))
    else:
        save_directory = os.path.dirname(args.resume)

    print("Save to {}".format(save_directory))
    log_file = os.path.join(save_directory, 'log-{0}.txt'.format(current_time_str))
    logger = log_utils.get_logger(log_file)
    log_utils.print_config(vars(args), logger)
    print_func = logger.info
    if config_file is not None:
        print_func('ConfigFile: {}'.format(config_file))
    else:
        print_func('ConfigFile: None, params from argparse')
    args.log_file = log_file

    # Restrict visible GPUs before any CUDA context is created.
    if args.device:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    args.distributed = args.world_size > 1
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model — two special archs take an extra param_name argument.
    if args.arch == 'resnet50otherinits':
        print_func("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True,
                                           num_classes=args.num_classes,
                                           param_name=args.paramname)
    elif args.arch == 'resnet50_feature_extractor':
        print_func("=> using pre-trained model '{}' to LOAD FEATURES".format(
            args.arch))
        model = models.__dict__[args.arch](pretrained=True,
                                           num_classes=args.num_classes,
                                           param_name=args.paramname)
    else:
        if args.pretrained:
            print_func("=> using pre-trained model '{}'".format(args.arch))
            model = models.__dict__[args.arch](pretrained=True,
                                               num_classes=args.num_classes)
        else:
            print_func("=> creating model '{}'".format(args.arch))
            model = models.__dict__[args.arch](pretrained=False,
                                               num_classes=args.num_classes)

    if args.freeze:
        model = CNN_utils.freeze_all_except_fc(model)

    # Device placement.
    if args.gpu is not None:
        model = model.cuda(args.gpu)
    elif args.distributed:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    # # Update: here
    # config = {'loss': {'type': 'simpleCrossEntropyLoss', 'args': {'param': None}}}
    # criterion = get_instance(loss_funcs, 'loss', config)
    # criterion = criterion.cuda(args.gpu)
    # ignore_index=-1 lets unlabeled targets be skipped by the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=-1).cuda(args.gpu)

    # Only parameters with requires_grad are optimized (supports args.freeze).
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                       model.parameters()),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.lr_schedule:
        print_func("Using scheduled learning rate")
        # Comma-separated epoch milestones, decay by 10x at each.
        scheduler = lr_scheduler.MultiStepLR(
            optimizer, [int(i) for i in args.lr_schedule.split(',')], gamma=0.1)
    else:
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min',
                                                   patience=args.lr_patience)
    # optimizer = torch.optim.SGD(model.parameters(), args.lr,
    #                             momentum=args.momentum,
    #                             weight_decay=args.weight_decay)

    # optionally resume from a checkpoint — full resume: epoch, best prec,
    # model weights and optimizer state are all restored.
    if args.resume:
        if os.path.isfile(args.resume):
            print_func("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_func("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_func("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    model_total_params = sum(p.numel() for p in model.parameters())
    model_grad_params = sum(p.numel() for p in model.parameters()
                            if p.requires_grad)
    print_func("Total Parameters: {0}\t Gradient Parameters: {1}".format(
        model_total_params, model_grad_params))

    # Data loading code — dataset class resolved by name '<name>_val'/'<name>_train'.
    val_dataset = get_instance(custom_datasets,
                               '{0}_val'.format(args.dataset.name), args,
                               **args.dataset.args)
    if val_dataset is None:
        val_loader = None
    else:
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True,
                                                 collate_fn=none_collate)

    if args.evaluate:
        validate(val_loader, model, criterion, args, print_func)
        return
    else:
        train_dataset = get_instance(custom_datasets,
                                     '{0}_train'.format(args.dataset.name),
                                     args, **args.dataset.args)
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset)
        else:
            train_sampler = None
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size,
            shuffle=(train_sampler is None), num_workers=args.workers,
            pin_memory=True, sampler=train_sampler, collate_fn=none_collate)

        for epoch in range(args.start_epoch, args.epochs):
            if args.distributed:
                train_sampler.set_epoch(epoch)
            if args.lr_schedule:
                # CNN_utils.adjust_learning_rate(optimizer, epoch, args.lr)
                scheduler.step()
            current_lr = optimizer.param_groups[0]['lr']
            print_func("Epoch: [{}], learning rate: {}".format(epoch, current_lr))

            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch, args, print_func)

            # evaluate on validation set
            if val_loader:
                prec1, val_loss = validate(val_loader, model, criterion, args, print_func)
            else:
                prec1 = 0
                val_loss = 0

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            CNN_utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, file_directory=save_directory, epoch=epoch)

            # Plateau scheduler steps on validation loss (only when no fixed
            # milestone schedule is configured).
            if not args.lr_schedule:
                scheduler.step(val_loss)