def __init__(self, args):
    self.args = args

    # Define Saver
    self.saver = Saver(args)
    self.saver.save_experiment_config()

    # Define Tensorboard Summary
    self.summary = TensorboardSummary(self.saver.experiment_dir)
    self.writer = self.summary.create_summary()

    self.use_amp = APEX_AVAILABLE and args.use_amp
    self.opt_level = args.opt_level

    kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
    self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

    if args.use_balanced_weights:
        classes_weights_path = os.path.join(Path.db_root_dir(args.dataset),
                                            args.dataset + '_classes_weights.npy')
        if os.path.isfile(classes_weights_path):
            weight = np.load(classes_weights_path)
        else:
            # Not implemented: if the weights had to be computed here, it is
            # unclear which train loader (A or B) should be used:
            # weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            raise NotImplementedError
        weight = torch.from_numpy(weight.astype(np.float32))
    else:
        weight = None
    self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

    # Define network
    model = AutoDeeplab(self.nclass, 12, self.criterion,
                        self.args.filter_multiplier,
                        self.args.block_multiplier,
                        self.args.step)
    optimizer = torch.optim.SGD(model.weight_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    self.model, self.optimizer = model, optimizer

    self.architect_optimizer = torch.optim.Adam(self.model.arch_parameters(),
                                                lr=args.arch_lr,
                                                betas=(0.9, 0.999),
                                                weight_decay=args.arch_weight_decay)

    # Define Evaluator
    self.evaluator = Evaluator(self.nclass)

    # Define lr scheduler
    # TODO: figure out whether len(self.train_loader) should be divided by two
    # here and in the other modules as well.
    self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                  len(self.train_loaderA), min_lr=args.min_lr)

    # Using cuda
    if args.cuda:
        self.model = self.model.cuda()

    # Mixed precision
    if self.use_amp and args.cuda:
        keep_batchnorm_fp32 = True if (self.opt_level == 'O2' or self.opt_level == 'O3') else None

        # Fix for PyTorch versions before 1.3 with opt_level 'O1'
        if self.opt_level == 'O1' and torch.__version__ < '1.3':
            for module in self.model.modules():
                if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                    # Hack to fix BN forward pass without affine transformation
                    if module.weight is None:
                        module.weight = torch.nn.Parameter(
                            torch.ones(module.running_var.shape,
                                       dtype=module.running_var.dtype,
                                       device=module.running_var.device),
                            requires_grad=False)
                    if module.bias is None:
                        module.bias = torch.nn.Parameter(
                            torch.zeros(module.running_var.shape,
                                        dtype=module.running_var.dtype,
                                        device=module.running_var.device),
                            requires_grad=False)

        self.model, [self.optimizer, self.architect_optimizer] = amp.initialize(
            self.model, [self.optimizer, self.architect_optimizer],
            opt_level=self.opt_level,
            keep_batchnorm_fp32=keep_batchnorm_fp32,
            loss_scale="dynamic")
        print('cuda finished')

    # Using data parallel
    if args.cuda and len(self.args.gpu_ids) > 1:
        if self.opt_level == 'O2' or self.opt_level == 'O3':
            print('currently cannot run with nn.DataParallel and optimization level',
                  self.opt_level)
        self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
        patch_replication_callback(self.model)
        print('training on multiple GPUs')

    # Resuming checkpoint
    self.best_pred = 0.0
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']

        # If the weights are wrapped in a DataParallel module object,
        # we have to clean the key names first.
        if args.clean_module:
            self.model.load_state_dict(checkpoint['state_dict'])
            state_dict = checkpoint['state_dict']
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove the 'module.' prefix added by DataParallel
                new_state_dict[name] = v
            copy_state_dict(self.model.state_dict(), new_state_dict)
        else:
            if torch.cuda.device_count() > 1 or args.load_parallel:
                copy_state_dict(self.model.module.state_dict(), checkpoint['state_dict'])
            else:
                copy_state_dict(self.model.state_dict(), checkpoint['state_dict'])

        if not args.ft:
            copy_state_dict(self.optimizer.state_dict(), checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    # Clear start epoch if fine-tuning
    if args.ft:
        args.start_epoch = 0
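
# copy_state_dict is used throughout these trainers but is not defined in this
# file. The sketch below is a minimal, assumed reconstruction of what such a
# helper typically does (copy entries whose names and shapes match, tolerating
# an optional 'module.' prefix); it is not the repository's actual implementation.
def copy_state_dict_sketch(target_state, source_state):
    """Copy matching tensors from source_state into target_state in place."""
    for k, v in source_state.items():
        name = k[7:] if k.startswith('module.') else k  # strip DataParallel prefix
        if name in target_state and target_state[name].shape == v.shape:
            target_state[name].copy_(v)  # state_dict tensors share storage with the model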

def __init__(self, args):
    self.args = args

    """ Define Saver """
    self.saver = Saver(args)
    self.saver.save_experiment_config()

    """ Define Tensorboard Summary """
    self.summary = TensorboardSummary(self.saver.experiment_dir)
    self.writer = self.summary.create_summary()

    """ Define Dataloader """
    kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
    self.train_loader, self.val_loader, _, self.nclass = make_data_loader(args, **kwargs)
    self.criterion = nn.L1Loss()

    if args.network == 'searched-dense':
        cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
        cell_arch = np.load(cell_path)

        if self.args.C == 2:
            C_index = [5]
            network_arch = [1, 2, 2, 2, 3, 2, 2, 1, 1, 1, 1, 2]
            low_level_layer = 0
        elif self.args.C == 3:
            C_index = [3, 7]
            network_arch = [1, 2, 3, 2, 2, 3, 2, 3, 2, 3, 2, 3]
            low_level_layer = 0
        elif self.args.C == 4:
            C_index = [2, 5, 8]
            network_arch = [1, 2, 3, 3, 2, 3, 3, 3, 3, 3, 2, 2]
            low_level_layer = 0

        model = ADD(network_arch, C_index, cell_arch, self.nclass, args, low_level_layer)

    elif args.network.startswith('autodeeplab'):
        network_arch = [0, 0, 0, 1, 2, 1, 2, 2, 3, 3, 2, 1]
        cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
        cell_arch = np.load(cell_path)
        low_level_layer = 2

        if self.args.C == 2:
            C_index = [5]
        elif self.args.C == 3:
            C_index = [3, 7]
        elif self.args.C == 4:
            C_index = [2, 5, 8]

        if args.network == 'autodeeplab-dense':
            model = ADD(network_arch, C_index, cell_arch, self.nclass, args, low_level_layer)
        elif args.network == 'autodeeplab-baseline':
            model = Baselin_Model(network_arch, C_index, cell_arch, self.nclass, args, low_level_layer)

    self.edm = EDM().cuda()
    optimizer = torch.optim.Adam(self.edm.parameters(), lr=args.lr)
    self.model, self.optimizer = model, optimizer

    if args.cuda:
        self.model = self.model.cuda()

    """ Resuming checkpoint """
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']

        """ if the weights are wrapped in module object we have to clean it """
        if args.clean_module:
            self.model.load_state_dict(checkpoint['state_dict'])
            state_dict = checkpoint['state_dict']
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove 'module.' of dataparallel
                new_state_dict[name] = v
            copy_state_dict(self.model.state_dict(), new_state_dict)
        else:
            if torch.cuda.device_count() > 1:
                copy_state_dict(self.model.module.state_dict(), checkpoint['state_dict'])
            else:
                copy_state_dict(self.model.state_dict(), checkpoint['state_dict'])

    if os.path.isfile('feature.npy'):
        train_feature = np.load('feature.npy')
        train_entropy = np.load('entropy.npy')
        train_set = TensorDataset(torch.tensor(train_feature),
                                  torch.tensor(train_entropy, dtype=torch.float))
        self.train_set = DataLoader(train_set,
                                    batch_size=self.args.train_batch,
                                    shuffle=True,
                                    pin_memory=True)
    else:
        self.make_data(self.args.train_batch)
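
# Sketch: how the cached (feature, entropy) pairs loaded above could be consumed
# to fit the EDM regressor with the L1 criterion. This loop is an illustration
# under that assumption, not the repository's actual training code; `trainer`
# stands for an instance of the class above.
def train_edm_epoch(trainer):
    trainer.edm.train()
    for feature, entropy in trainer.train_set:
        feature, entropy = feature.cuda(), entropy.cuda()
        trainer.optimizer.zero_grad()
        pred = trainer.edm(feature)               # predicted entropy per feature
        loss = trainer.criterion(pred, entropy)   # nn.L1Loss defined in __init__
        loss.backward()
        trainer.optimizer.step()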

def __init__(self, args):
    self.args = args

    # Define Saver
    self.saver = Saver(args)
    self.saver.save_experiment_config()

    # Define Tensorboard Summary
    self.summary = TensorboardSummary(self.saver.experiment_dir)
    self.writer = self.summary.create_summary()

    kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
    self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader = make_data_loader(args, **kwargs)

    # Define network
    model = AutoStereo(maxdisp=self.args.max_disp,
                       Fea_Layers=self.args.fea_num_layers,
                       Fea_Filter=self.args.fea_filter_multiplier,
                       Fea_Block=self.args.fea_block_multiplier,
                       Fea_Step=self.args.fea_step,
                       Mat_Layers=self.args.mat_num_layers,
                       Mat_Filter=self.args.mat_filter_multiplier,
                       Mat_Block=self.args.mat_block_multiplier,
                       Mat_Step=self.args.mat_step)

    optimizer_F = torch.optim.SGD(model.feature.weight_parameters(),
                                  args.lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
    optimizer_M = torch.optim.SGD(model.matching.weight_parameters(),
                                  args.lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
    self.model, self.optimizer_F, self.optimizer_M = model, optimizer_F, optimizer_M

    self.architect_optimizer_F = torch.optim.Adam(self.model.feature.arch_parameters(),
                                                  lr=args.arch_lr,
                                                  betas=(0.9, 0.999),
                                                  weight_decay=args.arch_weight_decay)
    self.architect_optimizer_M = torch.optim.Adam(self.model.matching.arch_parameters(),
                                                  lr=args.arch_lr,
                                                  betas=(0.9, 0.999),
                                                  weight_decay=args.arch_weight_decay)

    # Define lr scheduler
    self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                  len(self.train_loaderA), min_lr=args.min_lr)

    # Using cuda
    if args.cuda:
        self.model = torch.nn.DataParallel(self.model).cuda()

    # Resuming checkpoint
    self.best_pred = 100.0
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']

        # If the weights are wrapped in a DataParallel module object,
        # we have to clean the key names first.
        if args.clean_module:
            self.model.load_state_dict(checkpoint['state_dict'])
            state_dict = checkpoint['state_dict']
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove the 'module.' prefix added by DataParallel
                new_state_dict[name] = v
            copy_state_dict(self.model.state_dict(), new_state_dict)
        else:
            # The model is wrapped in DataParallel above, so both the single-
            # and multi-GPU cases load into .module (the original if/else had
            # identical branches once the debugging traces were removed).
            copy_state_dict(self.model.module.state_dict(), checkpoint['state_dict'])

        if not args.ft:
            copy_state_dict(self.optimizer_M.state_dict(), checkpoint['optimizer_M'])
            copy_state_dict(self.optimizer_F.state_dict(), checkpoint['optimizer_F'])
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    # Clear start epoch if fine-tuning
    if args.ft:
        args.start_epoch = 0

    print('Total number of model parameters: {}'.format(
        sum(p.data.nelement() for p in self.model.parameters())))
    print('Number of Feature Net parameters: {}'.format(
        sum(p.data.nelement() for p in self.model.module.feature.parameters())))
    print('Number of Matching Net parameters: {}'.format(
        sum(p.data.nelement() for p in self.model.module.matching.parameters())))
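
# The four optimizers above set up a DARTS-style bi-level search: SGD updates
# the Feature/Matching network weights on one data split while Adam updates the
# architecture parameters on the other. A minimal sketch of one alternating
# step, assuming a compute_loss(model, batch) callable (hypothetical helper;
# the actual training loop lives elsewhere in the repository):
def search_step(trainer, batch_A, batch_B, compute_loss):
    # Weight update on split A
    trainer.optimizer_F.zero_grad()
    trainer.optimizer_M.zero_grad()
    compute_loss(trainer.model, batch_A).backward()
    trainer.optimizer_F.step()
    trainer.optimizer_M.step()

    # Architecture update on split B
    trainer.architect_optimizer_F.zero_grad()
    trainer.architect_optimizer_M.zero_grad()
    compute_loss(trainer.model, batch_B).backward()
    trainer.architect_optimizer_F.step()
    trainer.architect_optimizer_M.step()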

def __init__(self, args):
    self.args = args

    """ Define Saver """
    self.saver = Saver(args)
    self.saver.save_experiment_config()

    """ Define Tensorboard Summary """
    self.summary = TensorboardSummary(self.saver.experiment_dir)
    self.writer = self.summary.create_summary()

    self.use_amp = APEX_AVAILABLE and args.use_amp
    self.opt_level = args.opt_level

    kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
    self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

    if args.use_balanced_weights:
        classes_weights_path = os.path.join(Path.db_root_dir(args.dataset),
                                            args.dataset + '_classes_weights.npy')
        if os.path.isfile(classes_weights_path):
            weight = np.load(classes_weights_path)
        else:
            """ If the weights must be computed here, it is unclear which train
                loader (A or B) to use; note that self.train_loader is not
                defined in this trainer. """
            weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
        weight = torch.from_numpy(weight.astype(np.float32))
    else:
        weight = None
    self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=255).cuda()

    """ Define network """
    if self.args.network == 'supernet':
        model = Model_search(self.nclass, 12, self.args, exit_layer=5)
    elif self.args.network == 'layer_supernet':
        cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
        cell_arch = np.load(cell_path)
        model = Model_layer_search(self.nclass, 12, self.args,
                                   exit_layer=5, alphas=cell_arch)

    optimizer = torch.optim.SGD(model.weight_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    self.model, self.optimizer = model, optimizer

    self.architect_optimizer = torch.optim.Adam(self.model.arch_parameters(),
                                                lr=args.arch_lr,
                                                betas=(0.9, 0.999),
                                                weight_decay=args.arch_weight_decay)

    """ Define Evaluator """
    self.evaluator_1 = Evaluator(self.nclass)
    self.evaluator_2 = Evaluator(self.nclass)

    """ Define lr scheduler """
    self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                  len(self.train_loaderA), min_lr=args.min_lr)

    """ Using cuda """
    if args.cuda:
        self.model = self.model.cuda()

    """ mixed precision """
    if self.use_amp and args.cuda:
        keep_batchnorm_fp32 = True if (self.opt_level == 'O2' or self.opt_level == 'O3') else None

        """ fix for PyTorch versions before 1.3 with opt_level 'O1' """
        if self.opt_level == 'O1' and torch.__version__ < '1.3':
            for module in self.model.modules():
                if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                    """ Hack to fix BN fprop without affine transformation """
                    if module.weight is None:
                        module.weight = torch.nn.Parameter(
                            torch.ones(module.running_var.shape,
                                       dtype=module.running_var.dtype,
                                       device=module.running_var.device),
                            requires_grad=False)
                    if module.bias is None:
                        module.bias = torch.nn.Parameter(
                            torch.zeros(module.running_var.shape,
                                        dtype=module.running_var.dtype,
                                        device=module.running_var.device),
                            requires_grad=False)

        self.model, [self.optimizer, self.architect_optimizer] = amp.initialize(
            self.model, [self.optimizer, self.architect_optimizer],
            opt_level=self.opt_level,
            keep_batchnorm_fp32=keep_batchnorm_fp32,
            loss_scale="dynamic")
        print('cuda finished')

    """ Using data parallel """
    if args.cuda and len(self.args.gpu_ids) > 1:
        if self.opt_level == 'O2' or self.opt_level == 'O3':
            print('currently cannot run with nn.DataParallel and optimization level',
                  self.opt_level)
        self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
        patch_replication_callback(self.model)
        print('training on multiple GPUs')

    """ Resuming checkpoint """
    self.best_pred = 0.0
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']

        """ if the weights are wrapped in module object we have to clean it """
        if args.clean_module:
            self.model.load_state_dict(checkpoint['state_dict'])
            state_dict = checkpoint['state_dict']
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove 'module.' of dataparallel
                new_state_dict[name] = v
            copy_state_dict(self.model.state_dict(), new_state_dict)
        else:
            if torch.cuda.device_count() > 1:
                copy_state_dict(self.model.module.state_dict(), checkpoint['state_dict'])
            else:
                copy_state_dict(self.model.state_dict(), checkpoint['state_dict'])
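
# Once amp.initialize has wrapped the model and optimizers as above, the
# backward pass in the training loop must route through amp's loss scaler.
# A minimal sketch of the standard Apex usage (shown for context; the actual
# training loop lives elsewhere in the repository):
def backward_with_amp(loss, optimizer, use_amp):
    """Backward pass that goes through Apex amp's dynamic loss scaling when enabled."""
    if use_amp:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()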

def __init__(self, args):
    self.args = args

    # Define Saver: creates a file where training information
    # (dataset, epochs, ...) is saved.
    self.saver = Saver(args)
    self.saver.save_experiment_config()

    kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
    self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

    # TODO: figure out what this is
    weight = None
    self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

    model = AutoDeeplab(self.nclass, 10, self.criterion,
                        self.args.filter_multiplier,
                        self.args.block_multiplier,
                        self.args.step)
    optimizer = torch.optim.SGD(model.weight_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    self.model, self.optimizer = model, optimizer

    self.architect_optimizer = torch.optim.Adam(self.model.arch_parameters(),
                                                lr=args.arch_lr,
                                                betas=(0.9, 0.999),
                                                weight_decay=args.arch_weight_decay)

    # Define Evaluator
    # TODO: figure out what this is
    self.evaluator = Evaluator(self.nclass)

    # Define lr scheduler
    self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                  len(self.train_loaderA), min_lr=args.min_lr)

    self.model = self.model.cuda()

    # Resuming checkpoint
    self.best_pred = 0.0
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']

        # If the weights are wrapped in a DataParallel module object,
        # we have to clean the key names first.
        if args.clean_module:
            self.model.load_state_dict(checkpoint['state_dict'])
            state_dict = checkpoint['state_dict']
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove the 'module.' prefix added by DataParallel
                new_state_dict[name] = v
            copy_state_dict(self.model.state_dict(), new_state_dict)
        else:
            copy_state_dict(self.model.state_dict(), checkpoint['state_dict'])

        if not args.ft:
            copy_state_dict(self.optimizer.state_dict(), checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    # Clear start epoch if fine-tuning
    if args.ft:
        args.start_epoch = 0
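
# The LR_Scheduler above is stepped once per batch rather than once per epoch.
# A minimal sketch of the assumed call pattern inside a weight-training epoch
# (the scheduler call signature follows the common pytorch-deeplab-xception
# convention; treat the loop body as an illustration, not the repository's
# actual code):
def run_weight_epoch(trainer, epoch):
    trainer.model.train()
    for i, sample in enumerate(trainer.train_loaderA):
        image, target = sample['image'].cuda(), sample['label'].cuda()
        trainer.scheduler(trainer.optimizer, i, epoch, trainer.best_pred)
        trainer.optimizer.zero_grad()
        output = trainer.model(image)
        loss = trainer.criterion(output, target)
        loss.backward()
        trainer.optimizer.step()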

def __init__(self, args):
    self.args = args

    """ Define Saver """
    self.saver = Saver(args)
    self.saver.save_experiment_config()

    """ Define Tensorboard Summary """
    self.summary = TensorboardSummary(self.saver.experiment_dir)
    self.writer = self.summary.create_summary()

    self.use_amp = self.args.use_amp
    self.opt_level = self.args.opt_level

    """ Define Dataloader """
    kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
    self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

    if args.network == 'searched_dense':
        """ 40_5e_lr_38_31.91 """
        # cell_path_1 = os.path.join(args.saved_arch_path, '40_5e_38_lr', 'genotype_1.npy')
        # cell_path_2 = os.path.join(args.saved_arch_path, '40_5e_38_lr', 'genotype_2.npy')
        # cell_arch_1 = np.load(cell_path_1)
        # cell_arch_2 = np.load(cell_path_2)
        # network_arch = [1, 2, 3, 2, 3, 2, 2, 1, 2, 1, 1, 2]

        cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
        cell_arch = np.load(cell_path)
        network_arch = [0, 1, 2, 3, 2, 2, 2, 2, 1, 2, 3, 2]
        low_level_layer = 0
        model = Model_2(network_arch, cell_arch, self.nclass, args, low_level_layer)

    elif args.network == 'searched_baseline':
        cell_path_1 = os.path.join(args.saved_arch_path, 'searched_baseline', 'genotype_1.npy')
        cell_path_2 = os.path.join(args.saved_arch_path, 'searched_baseline', 'genotype_2.npy')
        cell_arch_1 = np.load(cell_path_1)
        cell_arch_2 = np.load(cell_path_2)
        # Assumption: the baseline model consumes both searched genotypes;
        # the original code passed an undefined `cell_arch` here.
        cell_arch = [cell_arch_1, cell_arch_2]
        network_arch = [0, 1, 2, 2, 3, 2, 2, 1, 2, 1, 1, 2]
        low_level_layer = 1
        model = Model_2_baseline(network_arch, cell_arch, self.nclass, args, low_level_layer)

    elif args.network.startswith('autodeeplab'):
        network_arch = [0, 0, 0, 1, 2, 1, 2, 2, 3, 3, 2, 1]
        cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
        cell_arch = np.load(cell_path)
        low_level_layer = 2

        if args.network == 'autodeeplab-dense':
            model = Model_2(network_arch, cell_arch, self.nclass, args, low_level_layer)
        elif args.network == 'autodeeplab-baseline':
            model = Model_2_baseline(network_arch, cell_arch, self.nclass, args, low_level_layer)

    """ Define Optimizer """
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)

    """ Define Criterion: whether to use class-balanced weights """
    if args.use_balanced_weights:
        classes_weights_path = os.path.join(Path.db_root_dir(args.dataset),
                                            args.dataset + '_classes_weights.npy')
        if os.path.isfile(classes_weights_path):
            weight = np.load(classes_weights_path)
        else:
            weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
        weight = torch.from_numpy(weight.astype(np.float32))
    else:
        weight = None
    self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=255).cuda()
    self.model, self.optimizer = model, optimizer

    """ Define Evaluator """
    self.evaluator_1 = Evaluator(self.nclass)
    self.evaluator_2 = Evaluator(self.nclass)

    """ Define lr scheduler """
    self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                  args.epochs, len(self.train_loader))

    if args.cuda:
        self.model = self.model.cuda()

    """ mixed precision """
    if self.use_amp and args.cuda:
        keep_batchnorm_fp32 = True if (self.opt_level == 'O2' or self.opt_level == 'O3') else None

        """ fix for PyTorch versions before 1.3 with opt_level 'O1' """
        if self.opt_level == 'O1' and torch.__version__ < '1.3':
            for module in self.model.modules():
                if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) \
                        or isinstance(module, SynchronizedBatchNorm2d):
                    """ Hack to fix BN fprop without affine transformation """
                    if module.weight is None:
                        module.weight = torch.nn.Parameter(
                            torch.ones(module.running_var.shape,
                                       dtype=module.running_var.dtype,
                                       device=module.running_var.device),
                            requires_grad=False)
                    if module.bias is None:
                        module.bias = torch.nn.Parameter(
                            torch.zeros(module.running_var.shape,
                                        dtype=module.running_var.dtype,
                                        device=module.running_var.device),
                            requires_grad=False)

        self.model, self.optimizer = amp.initialize(
            self.model, self.optimizer,
            opt_level=self.opt_level,
            keep_batchnorm_fp32=keep_batchnorm_fp32,
            loss_scale="dynamic")

    """ Using data parallel """
    if args.cuda and len(self.args.gpu_ids) > 1:
        if self.opt_level == 'O2' or self.opt_level == 'O3':
            print('currently cannot run with nn.DataParallel and optimization level',
                  self.opt_level)
        self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
        patch_replication_callback(self.model)
        print('training on multiple GPUs')

    """ Resuming checkpoint """
    self.best_pred = 0.0
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']

        """ if the weights are wrapped in module object we have to clean it """
        if args.clean_module:
            self.model.load_state_dict(checkpoint['state_dict'])
            state_dict = checkpoint['state_dict']
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove 'module.' of dataparallel
                new_state_dict[name] = v
            copy_state_dict(self.model.state_dict(), new_state_dict)
        else:
            if torch.cuda.device_count() > 1:
                copy_state_dict(self.model.module.state_dict(), checkpoint['state_dict'])
            else:
                copy_state_dict(self.model.state_dict(), checkpoint['state_dict'])

        if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    """ Clear start epoch if fine-tuning """
    if args.ft:
        args.start_epoch = 0
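
# calculate_weigths_labels is referenced above but not defined in this file.
# The sketch below is an assumed reconstruction of the usual balanced-weight
# recipe (count per-class pixels over the loader, then weight each class by
# 1/log(1.02 + frequency)); it assumes the loader yields dicts with a 'label'
# tensor, as elsewhere in these trainers.
def calculate_class_weights_sketch(dataloader, num_classes):
    counts = np.zeros(num_classes)
    for sample in dataloader:
        labels = sample['label'].numpy()
        valid = (labels >= 0) & (labels < num_classes)  # skip the 255 ignore index
        counts += np.bincount(labels[valid].astype(int).ravel(), minlength=num_classes)
    frequency = counts / counts.sum()
    return 1.0 / np.log(1.02 + frequency)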