class Loader(object):
    """Rebuild a searched AutoDeeplab super-network and decode its architecture.

    The constructor builds an ``AutoDeeplab`` model from parsed command-line
    ``args``, optionally moves it to the GPU (wrapping it in
    ``torch.nn.DataParallel`` when several GPUs are visible or
    ``args.load_parallel`` is set), restores weights from ``args.resume`` when
    given, and creates a ``Decoder`` over the model's architecture parameters
    (``alphas``, ``bottom_betas``, ``betas8``, ``betas16``, ``top_betas``).
    """

    # Prefix that torch.nn.DataParallel prepends to every state_dict key.
    _PARALLEL_PREFIX = 'module.'

    def __init__(self, args):
        """Build the model, optionally restore a checkpoint, and set up the decoder.

        Args:
            args: parsed argument namespace. Read here: ``dataset``,
                ``filter_multiplier``, ``block_multiplier``, ``step``,
                ``cuda``, ``load_parallel``, ``resume`` and (when resuming)
                ``clean_module``. ``args.start_epoch`` is overwritten with the
                checkpoint's ``'epoch'`` entry when resuming.

        Raises:
            ValueError: if ``args.dataset`` is not a supported dataset.
            RuntimeError: if ``args.resume`` is set but is not an existing file.
        """
        self.args = args
        if args.dataset == 'cityscapes':
            self.nclass = 19
        else:
            # Without this, self.nclass stays unset and the AutoDeeplab call
            # below dies with an unhelpful AttributeError.
            raise ValueError("unsupported dataset: '{}'".format(args.dataset))

        self.model = AutoDeeplab(num_classes=self.nclass, num_layers=12,
                                 filter_multiplier=args.filter_multiplier,
                                 block_multiplier=args.block_multiplier,
                                 step=args.step)

        # Using cuda
        if args.cuda:
            if torch.cuda.device_count() > 1 or args.load_parallel:
                self.model = torch.nn.DataParallel(self.model.cuda())
                patch_replication_callback(self.model)
            self.model = self.model.cuda()
            print('cuda finished')

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            # On CPU-only runs map tensors to CPU so GPU-saved checkpoints
            # still load; load_state_dict copies values onto each
            # parameter's own device anyway.
            checkpoint = torch.load(args.resume,
                                    map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            self._load_weights(checkpoint['state_dict'], args.clean_module)

        # Architecture parameters live on the bare network, not on a
        # DataParallel wrapper, so always read them from the unwrapped model.
        net = self._bare_model()
        self.decoder = Decoder(net.alphas, net.bottom_betas, net.betas8,
                               net.betas16, net.top_betas,
                               args.block_multiplier, args.step)

    def _bare_model(self):
        """Return the underlying network, unwrapping torch.nn.DataParallel if present."""
        if isinstance(self.model, torch.nn.DataParallel):
            return self.model.module
        return self.model

    @classmethod
    def _strip_module_prefix(cls, state_dict):
        """Return an OrderedDict copy of ``state_dict`` without the DataParallel key prefix.

        Only keys that actually start with ``'module.'`` are shortened; other
        keys are kept unchanged (a blind ``k[7:]`` slice would mangle them).
        """
        prefix = cls._PARALLEL_PREFIX
        return OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items())

    def _load_weights(self, state_dict, clean_module):
        """Load ``state_dict`` into the bare (unwrapped) model.

        Args:
            state_dict: mapping of parameter/buffer names to tensors.
            clean_module: if true, strip the ``'module.'`` prefix that
                DataParallel adds to keys before loading.
        """
        if clean_module:
            state_dict = self._strip_module_prefix(state_dict)
        # Loading into the bare model works whether or not self.model is
        # wrapped, and whether or not CUDA is enabled.
        self._bare_model().load_state_dict(state_dict)

    def retreive_alphas_betas(self):
        """Return the model's architecture parameters.

        Returns:
            Tuple ``(alphas, bottom_betas, betas8, betas16, top_betas)`` read
            from the unwrapped model.
        """
        net = self._bare_model()
        return net.alphas, net.bottom_betas, net.betas8, net.betas16, net.top_betas

    # Correctly spelled alias; the original (misspelled) name is kept for
    # existing callers.
    retrieve_alphas_betas = retreive_alphas_betas

    def decode_architecture(self):
        """Decode the network-level path via ``self.decoder.viterbi_decode()``.

        Returns:
            The ``(paths, paths_space)`` pair produced by the decoder.
        """
        paths, paths_space = self.decoder.viterbi_decode()
        return paths, paths_space

    def decode_cell(self):
        """Decode the cell-level genotype via ``self.decoder.genotype_decode()``."""
        genotype = self.decoder.genotype_decode()
        return genotype
type=int, default=32, help='F in paper') parser.add_argument('--steps', type=int, default=5, help='B in paper') parser.add_argument('--down_sample_level', type=int, default=8, help='s in paper') return parser.parse_args() if __name__ == "__main__": import torch import time args = obtain_default_search_args() criterion = torch.nn.CrossEntropyLoss(ignore_index=255) model = AutoDeeplab(num_classes=19, num_layers=12, criterion=criterion, filter_multiplier=args.filter_multiplier, block_multiplier=args.block_multiplier, step=args.step) model = nn.DataParallel(model).cuda() # torch.save(model.state_dict(), './checkpoint.pts.tar') checkpoint = torch.load('./checkpoint.pts.tar') st = time.time() # copy_state_dict(model.state_dict(), checkpoint) model.load_state_dict(checkpoint) et = time.time() print(et - st)