ax0.legend() ax1.legend() fig.savefig(os.path.join('./model', name, 'train.jpg')) ###################################################################### # Finetuning the convnet # ---------------------- # # Load a pretrainied model and reset final fully connected layer. # if opt.views == 2: model = two_view_net(len(class_names), droprate=opt.droprate, stride=opt.stride, pool=opt.pool, share_weight=opt.share) elif opt.views == 3: model = three_view_net(len(class_names), droprate=opt.droprate, stride=opt.stride, pool=opt.pool, share_weight=opt.share) opt.nclasses = len(class_names) print(model) # For resume: if start_epoch >= 40: opt.lr = opt.lr * 0.1
def _parse_checkpoint_epoch(model_path):
    """Return the epoch tag encoded in a checkpoint path such as ``net_059.pth``.

    The tag is the text between the first ``_`` and the following ``.`` of the
    file's base name. ``'last'`` is returned unchanged as a ``str``; any other
    tag is converted with ``int`` (so a non-numeric tag raises ``ValueError``).
    """
    tag = os.path.basename(model_path).split('_')[1].split('.')[0]
    return tag if tag == 'last' else int(tag)


# Options that must be present in opts.yaml; a missing key raises KeyError.
_REQUIRED_CONFIG_KEYS = (
    'name', 'data_dir', 'train_all', 'droprate', 'color_jitter', 'batchsize',
    'h', 'w', 'share', 'stride', 'erasing_p', 'lr', 'nclasses', 'use_dense',
    'fp16', 'views',
)
# Options copied onto ``opt`` only when opts.yaml contains them.
_OPTIONAL_CONFIG_KEYS = ('pool', 'gpu_ids', 'use_vgg16')


def _apply_saved_config(opt, config):
    """Copy the saved training options from *config* onto *opt* in place."""
    for key in _REQUIRED_CONFIG_KEYS:
        setattr(opt, key, config[key])
    for key in _OPTIONAL_CONFIG_KEYS:
        if key in config:
            setattr(opt, key, config[key])


def _build_model(opt, config):
    """Construct the (untrained) network described by *opt*.

    Precedence: a two/three-view net (with ``VGG16`` passed only when the saved
    config recorded ``use_vgg16``), then ``PCB``, then ``ft_net_dense``. Only the
    selected network is constructed.

    Raises:
        ValueError: if no architecture matches the options.
    """
    if opt.views in (2, 3):
        net_cls = two_view_net if opt.views == 2 else three_view_net
        kwargs = {'stride': opt.stride, 'pool': opt.pool, 'share_weight': opt.share}
        if 'use_vgg16' in config:
            kwargs['VGG16'] = opt.use_vgg16
        return net_cls(opt.nclasses, opt.droprate, **kwargs)
    if getattr(opt, 'PCB', False):
        return PCB(opt.nclasses)
    if opt.use_dense:
        return ft_net_dense(opt.nclasses, opt.droprate, opt.stride, None, opt.pool)
    raise ValueError(
        'cannot build a model: views=%r and neither PCB nor use_dense is set'
        % (opt.views,))


def load_network(name, opt):
    """Rebuild a trained model from ``./model/<name>`` and load its weights.

    Reads ``opts.yaml`` from the model directory, overwrites the matching
    attributes of *opt* with the saved training options, constructs the
    network those options describe, and loads the checkpoint file selected by
    ``get_model_list(dirname, 'net')``.

    Args:
        name: sub-directory of ``./model`` holding ``opts.yaml`` and checkpoints.
        opt: options namespace; mutated in place with the saved settings.

    Returns:
        ``(network, opt, epoch)`` where ``epoch`` is an ``int`` or ``'last'``.

    Raises:
        FileNotFoundError: if no checkpoint is found in the model directory.
        KeyError: if ``opts.yaml`` lacks a required option.
        ValueError: if the options describe no known architecture.
    """
    dirname = os.path.join('./model', name)
    last_model_path = get_model_list(dirname, 'net')
    if not last_model_path:
        # NOTE(review): assumes get_model_list signals "nothing found" with a
        # falsy value such as None; previously this surfaced as a TypeError
        # from os.path.basename(None).
        raise FileNotFoundError('no checkpoint found in %s' % dirname)
    epoch = _parse_checkpoint_epoch(last_model_path)

    config_path = os.path.join(dirname, 'opts.yaml')
    with open(config_path, 'r') as stream:
        # yaml.load without an explicit Loader raises TypeError on PyYAML >= 6
        # and can construct arbitrary objects on older versions. safe_load
        # assumes opts.yaml holds only plain scalars/lists/dicts.
        config = yaml.safe_load(stream)

    _apply_saved_config(opt, config)
    network = _build_model(opt, config)

    if isinstance(epoch, int):
        save_filename = 'net_%03d.pth' % epoch
    else:
        save_filename = 'net_%s.pth' % epoch
    save_path = os.path.join(dirname, save_filename)
    print('Load the model from %s' % save_path)
    # Map to CPU so GPU-saved checkpoints load on CPU-only hosts; the freshly
    # built network is on CPU, and load_state_dict copies values into it.
    network.load_state_dict(torch.load(save_path, map_location='cpu'))
    return network, opt, epoch