def init_network():
    """Load the projection checkpoint and the BERT encoder, all in eval mode.

    Returns:
        tuple: ``(project_net, bert_tokenizer, bert_model)`` — the MLP
        projection head restored from ``latest.tar``, the bert-base-uncased
        tokenizer, and the bert-base-uncased model with hidden states exposed.

    NOTE(review): ``model_path`` is hard-coded to "" so the checkpoint is
    expected at ./latest.tar — confirm against the deployment layout.
    """
    model_path = ""
    model_file_path = os.path.join(model_path, "latest.tar")
    # torch.load unpickles arbitrary objects — only load trusted checkpoints.
    checkpoint = torch.load(model_file_path)
    project_net = MLP(768, 768, [128, 64])
    project_net.load_state_dict(checkpoint["project_net"])
    bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    bert_model = BertModel.from_pretrained("bert-base-uncased",
                                           output_hidden_states=True)
    bert_model.eval()
    project_net.eval()
    # Bug fix: the original built these objects and then discarded them (the
    # function returned None), so callers could never use the loaded network.
    return project_net, bert_tokenizer, bert_model
def main(cfg):
    """Gradient-based network editing ("NCP"): starting from a seed channel
    configuration, repeatedly nudge the (normalized) configuration by SGD so a
    pretrained predictor's main metric increases (and, optionally, predicted
    FLOPs decrease), clamping and re-rounding after every step. The trajectory
    of edited networks is dumped to ``{cfg.log_dir}/NCP.pkl``.

    Args:
        cfg: hydra/omegaconf-style config with ``data``, ``network``,
            ``editing`` and ``optimization`` sections (schema defined by the
            surrounding project).
    """
    device = set_env(cfg)
    logging.info('Loading the dataset.')
    _, criterion = get_criterion(cfg.optimization.criterion)
    # NOTE(review): the dataloaders are never used below — kept in case
    # get_dataloader has required side effects; confirm and drop otherwise.
    train_dataloader, val_dataloader = get_dataloader(cfg)
    model = MLP(**cfg.network).to(device)
    model.load_state_dict(
        torch.load(
            f'{cfg.editing.model_path}/{cfg.data.dataset}/best_model.pth'))
    model.eval()
    logging.info(f'Constructing model on the {device}:{cfg.CUDA_DEVICE}.')
    logging.info(model)
    # Per-dimension search bounds over the 27 configuration values; entries of
    # 64 are channel counts, entries of 2 are small structural choices.
    cfg.data.max_value = torch.tensor([
        64, 64, 2, 2, 64, 2, 2, 2, 64, 64, 2, 2, 2, 2, 64, 64, 64, 2, 2, 2,
        2, 2, 64, 64, 64, 64, 64
    ]) / 2 * cfg.editing.max_value_alpha
    cfg.data.min_value = torch.tensor([
        16, 16, 2, 2, 16, 2, 2, 2, 16, 16, 2, 2, 2, 2, 16, 16, 16, 2, 2, 2,
        2, 2, 16, 16, 16, 16, 16
    ]) / 2
    cfg.data.normalized_max_value = (cfg.data.max_value -
                                     cfg.data.input_mean) / cfg.data.input_std
    cfg.data.normalized_min_value = (cfg.data.min_value -
                                     cfg.data.input_mean) / cfg.data.input_std
    # Seed configuration to edit from; alternates for other tasks below.
    data = torch.tensor([[
        64, 64, 2, 2, 64, 2, 2, 2, 64, 64, 2, 2, 2, 2, 64, 64, 64, 2, 2, 2,
        2, 2, 64, 64, 64, 64, 64
    ]])
    # data = torch.tensor([[40, 40, 1, 4, 8, 1, 1, 2, 104, 64, 1, 3, 1, 1, 56, 32, 88, 3, 1, 1, 3, 3, 8, 64, 128, 128, 16]])  # seg
    # data = torch.tensor([[88, 128, 1, 1, 128, 1, 1, 4, 120, 32, 2, 1, 2, 4, 128, 128, 128, 1, 1, 1, 1, 1, 32, 128, 8, 8, 128]])  # cls
    # data = torch.tensor([[8, 8, 1, 1, 88, 1, 1, 1, 8, 8, 1, 1, 1, 1, 64, 128, 128, 1, 1, 1, 4, 80, 128, 128, 48, 128]])  # video
    # data = torch.tensor([[120, 48, 1, 1, 24, 1, 3, 4, 80, 128, 1, 1, 3, 1, 96, 8, 128, 1, 1, 1, 2, 1, 40, 80, 40, 96, 112]])  # 3ddet
    normalized_data = normalize(data, cfg, device)
    denormalized_data = data[0].numpy()
    rounded_data = denormalized_data.copy()
    flops, params = net2flops(data[0].int().cpu().numpy().tolist(), device)
    edit_net_set = []
    # `step` (was `iter`, which shadowed the builtin) indexes editing rounds.
    for step in tqdm(range(cfg.editing.iters)):
        # A fresh SGD each round is required because `normalized_data` is
        # rebound to a new tensor at the end of every round; plain SGD without
        # momentum is stateless, so this is equivalent to reusing one.
        optimizer = torch.optim.SGD([normalized_data], lr=cfg.editing.lr)
        optimizer.zero_grad()
        model.zero_grad()
        pred = model(normalized_data)[0]
        # Predicted main metric for the *rounded* (feasible) configuration.
        main_record = model(
            normalize(torch.Tensor(rounded_data).unsqueeze(0), cfg,
                      device))[0][0]
        main_metric = pred[0]
        # Chase a target slightly above the current prediction each round.
        main_metric_target = main_metric.clone().detach(
        ) + cfg.editing.per_step_increase
        loss = criterion(main_metric, main_metric_target)
        net_dict = {
            'rounded_net': rounded_data,
            'continuous_net': denormalized_data,
            'predicted_metrics':
            pred.detach().cpu().numpy().tolist() * cfg.data.output_std +
            cfg.data.output_mean,
            'main_metric':
            main_record.detach().cpu().item() * cfg.data.output_std[0] +
            cfg.data.output_mean[0],
            'flops': flops,
            'params': params
        }
        edit_net_set.append(net_dict)
        print(net_dict)
        if cfg.editing.use_flops:
            # Penultimate predictor output is treated as predicted FLOPs;
            # push it down by a fixed amount per round.
            flops = pred[-2]
            flops_target = flops.clone().detach(
            ) - cfg.editing.per_flops_decrease
            loss = loss + cfg.editing.alpha * criterion(flops, flops_target)
        loss.backward()
        optimizer.step()
        # Clamp each dimension back into its normalized feasible range.
        # no_grad: backward already ran and the graph is discarded just below,
        # so this is observationally identical — and it avoids the RuntimeError
        # recent PyTorch raises for in-place edits of a grad-requiring leaf.
        with torch.no_grad():
            for i in range(normalized_data.shape[1]):
                if normalized_data[0][i] > cfg.data.normalized_max_value[i]:
                    normalized_data[0][i] = cfg.data.normalized_max_value[i]
                if normalized_data[0][i] < cfg.data.normalized_min_value[i]:
                    normalized_data[0][i] = cfg.data.normalized_min_value[i]
        # Rebind to a fresh leaf for the next round's optimizer.
        normalized_data = normalized_data.detach().clone()
        normalized_data.requires_grad = True
        denormalized_data = denormalize(normalized_data, cfg)
        rounded_data = denormalized_data.copy()
        for i in range(rounded_data.shape[0]):
            rounded_data[i] = _make_divisible(rounded_data[i],
                                              cfg.data.min_value[i])
        flops, params = net2flops(list(rounded_data.astype(int)), device)
    # Bug fix: the original passed an anonymous open(...) to pickle.dump and
    # leaked the file handle; `with` guarantees flush + close.
    with open(f"{cfg.log_dir}/NCP.pkl", "wb") as fh:
        pickle.dump(edit_net_set, fh)
class NormalPolicy:
    """Gaussian policy: two MLP heads predict, per state, the mean and the
    standard deviation of a Normal distribution from which actions are drawn.
    """

    def __init__(self, layers, sigma, activation=F.relu):
        """Build the mean and std-dev networks.

        Args:
            layers: layer specification forwarded to ``MLP``.
            sigma: unused — kept only for interface compatibility.
                NOTE(review): the std-dev head ignores it and reuses
                ``layers`` with a softplus activation; confirm intent.
            activation: activation for the mean network (default ``F.relu``).
        """
        self.mu_net = MLP(layers, activation)
        # softplus keeps the predicted standard deviation strictly positive
        self.sigma = MLP(layers, activation=F.softplus)

    def get_mu(self, states):
        """Return the predicted action mean(s) for ``states``."""
        return self.mu_net.forward(states)

    def get_sigma(self, states):
        """Return the predicted action std-dev(s) for ``states``."""
        return self.sigma.forward(states)

    def get_action(self, state):
        """Sample an action from N(mu(state), sigma(state)).

        A 1-D ``state`` is promoted (in place) to a batch of one.
        """
        if state.dim() < 2:
            state.unsqueeze_(0)
        mean = self.get_mu(state)
        std_dev = self.get_sigma(state)
        # Bug fix: the original called mean.squeeze() / std_dev.squeeze()
        # without binding the results — squeeze() is not in-place, so both
        # calls were no-ops; they are removed here (behavior unchanged).
        m = torch.normal(mean, std_dev)
        return m.data

    def optimize(self, max_epochs_opt, train_dataset, val_dataset, batch_size,
                 learning_rate, verbose=False):
        """Train both heads on ``train_dataset`` with Adagrad, early-stopping
        when the validation loss fails to improve by 1e-3 for 3 epochs, then
        restore the best mean-network weights.

        NOTE(review): reconstructed from mangled source — the validation pass
        is placed once per epoch (after the batch loop), which matches the
        per-epoch stopping counters; confirm against the original layout.
        """
        train_data_loader = DataLoader(train_dataset, batch_size=batch_size,
                                       shuffle=True, num_workers=4)
        # One optimizer covers BOTH heads despite the historical "_mu" name.
        optimizer_mu = optim.Adagrad(
            [{'params': self.mu_net.parameters()},
             {'params': self.sigma.parameters()}],
            lr=learning_rate)
        best_model = None
        last_loss_opt = None
        epochs_opt_no_decrease = 0
        epoch_opt = 0
        while (epoch_opt < max_epochs_opt) and (epochs_opt_no_decrease < 3):
            for batch_idx, batch in enumerate(train_data_loader):
                optimizer_mu.zero_grad()
                # forward pass: batch = (states, actions, weights)
                mu = self.mu_net(batch[0])
                sigma = self.get_sigma(batch[0])
                loss = NormalPolicyLoss(mu, sigma, batch[1], batch[2])
                # backpropagate
                loss.backward()
                optimizer_mu.step()
            # validation loss on the whole held-out set
            mu = self.get_mu(val_dataset[0])
            sigma = self.get_sigma(val_dataset[0])
            cur_loss_opt = NormalPolicyLoss(mu, sigma, val_dataset[1],
                                            val_dataset[2])
            if verbose:
                sys.stdout.write('\r[policy] epoch: %d | loss: %f' %
                                 (epoch_opt + 1, cur_loss_opt))
                sys.stdout.flush()
            if (last_loss_opt is None) or (cur_loss_opt < last_loss_opt - 1e-3):
                # Bug fix: state_dict() returns live tensor references, so the
                # original snapshot kept tracking later updates and the final
                # load_state_dict was effectively a no-op. Clone the tensors
                # to freeze the best weights.
                best_model = {k: v.detach().clone()
                              for k, v in self.mu_net.state_dict().items()}
                epochs_opt_no_decrease = 0
                last_loss_opt = cur_loss_opt
            else:
                epochs_opt_no_decrease += 1
            epoch_opt += 1
        # NOTE(review): only mu_net's best weights are restored; the sigma
        # head keeps its final (possibly non-best) weights — confirm intent.
        self.mu_net.load_state_dict(best_model)
        if verbose:
            sys.stdout.write('\r[policy] training complete (%d epochs, %f best loss)'
                             % (epoch_opt, last_loss_opt) +
                             (' ' * (len(str(epoch_opt))) * 2 + '\n'))