def model_train(file_path, dataset, model_type, savedir, **kwargs):
    # load data
    # TODO: fill in the remaining datasets below
    if dataset == 'titanic':
        titanic = TitanicData(file_path)
        (x_train, y_train), _ = titanic.transform(scaling=kwargs.pop('scaling'))
    elif dataset == 'house_price':
        house_price = HousePriceData(file_path)
        (x_train, y_train), _ = house_price.transform(scaling=kwargs.pop('scaling'))
    elif dataset == 'bike_sharing':
        pass
    elif dataset == 'cervical_cancer':
        cervical_cancer = CervicalCancerData(file_path)
        x_train, y_train = cervical_cancer.transform(scaling=kwargs.pop('scaling'))
    elif dataset == 'youtube_spam':
        pass
    print('Complete Data Pre-processing')

    # add argument
    if model_type == 'DNN':
        kwargs['params']['nb_features'] = x_train.shape[1]

    # model training
    clf = Classifier(model_type=model_type, **kwargs.pop('params'))
    clf.train(x_train, y_train, savedir, **kwargs)
    print('Complete Training Model')
    print('Complete Saving Model')
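# A hypothetical invocation of model_train() above, assuming a 'titanic' CSV
# and a DNN classifier. The keys inside params (and the extra kwargs forwarded
# to clf.train) depend on the Classifier implementation and are assumptions.
model_train(
    file_path='data/titanic.csv',   # hypothetical path
    dataset='titanic',
    model_type='DNN',
    savedir='./saved_models',
    scaling=True,                   # popped and forwarded to .transform()
    params={'nb_layers': 2},        # hypothetical; nb_features is filled in for DNN
)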
def train_and_predict(train_df, test_df):
    # clean the data
    cleaner = DataCleaner()
    cleaner.columns_with_no_nan(train_df)
    cleaner.columns_with_no_nan(test_df)
    train_df = cleaner.drop_columns(train_df)
    train_df = cleaner.resolve_nan(train_df)
    test_df = cleaner.drop_columns(test_df)
    test_df = cleaner.resolve_nan(test_df)

    # feature engineering
    train_df, test_df = engineer_features(train_df, test_df)

    # build the model from Model
    model = Classifier()

    # label encoding / one-hot encoding
    train_df = model.encode(train_df)
    test_df = model.encode(test_df)

    # training progress and results
    model.train(train_df)

    # predict on test_df with the predict method from Model
    y_test = model.predict(test_df)
    return y_test
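# DataCleaner is referenced above but not defined in this snippet. A minimal
# pandas-based sketch of the assumed interface (method names taken from the
# call sites above; the drop/fill strategy here is an illustrative assumption):
import pandas as pd

class DataCleaner:
    def __init__(self, nan_threshold=0.5):
        self.nan_threshold = nan_threshold  # drop columns that are mostly NaN
        self.columns_to_drop = []

    def columns_with_no_nan(self, df):
        # report (and remember) columns whose NaN ratio exceeds the threshold
        ratios = df.isna().mean()
        self.columns_to_drop = ratios[ratios > self.nan_threshold].index.tolist()
        return [c for c in df.columns if c not in self.columns_to_drop]

    def drop_columns(self, df):
        return df.drop(columns=self.columns_to_drop, errors='ignore')

    def resolve_nan(self, df):
        # fill numeric NaNs with the median, others with the mode
        for col in df.columns:
            if pd.api.types.is_numeric_dtype(df[col]):
                df[col] = df[col].fillna(df[col].median())
            elif not df[col].mode().empty:
                df[col] = df[col].fillna(df[col].mode()[0])
        return df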
class Agent():
    def __init__(self, state_size, action_size, config):
        self.env_name = config["env_name"]
        self.state_size = state_size
        self.action_size = action_size
        self.seed = config["seed"]
        self.clip = config["clip"]
        self.device = 'cuda'
        print("Clip ", self.clip)
        print("cuda ", torch.cuda.is_available())
        self.double_dqn = config["DDQN"]
        print("Use double dqn", self.double_dqn)
        self.lr_pre = config["lr_pre"]
        self.batch_size = config["batch_size"]
        self.lr = config["lr"]
        self.tau = config["tau"]
        print("self tau", self.tau)
        self.gamma = 0.99
        self.fc1 = config["fc1_units"]
        self.fc2 = config["fc2_units"]
        self.fc3 = config["fc3_units"]
        self.qnetwork_local = QNetwork(state_size, action_size, self.fc1, self.fc2, self.fc3, self.seed).to(self.device)
        self.qnetwork_target = QNetwork(state_size, action_size, self.fc1, self.fc2, self.fc3, self.seed).to(self.device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=self.lr)
        self.soft_update(self.qnetwork_local, self.qnetwork_target, 1)
        self.q_shift_local = QNetwork(state_size, action_size, self.fc1, self.fc2, self.fc3, self.seed).to(self.device)
        self.q_shift_target = QNetwork(state_size, action_size, self.fc1, self.fc2, self.fc3, self.seed).to(self.device)
        self.optimizer_shift = optim.Adam(self.q_shift_local.parameters(), lr=self.lr)
        self.soft_update(self.q_shift_local, self.q_shift_target, 1)
        self.R_local = QNetwork(state_size, action_size, self.fc1, self.fc2, self.fc3, self.seed).to(self.device)
        self.R_target = QNetwork(state_size, action_size, self.fc1, self.fc2, self.fc3, self.seed).to(self.device)
        self.optimizer_r = optim.Adam(self.R_local.parameters(), lr=self.lr)
        self.soft_update(self.R_local, self.R_target, 1)
        self.expert_q = DQNetwork(state_size, action_size, seed=self.seed).to(self.device)
        self.expert_q.load_state_dict(torch.load('checkpoint.pth'))
        self.memory = Memory(action_size, config["buffer_size"], self.batch_size, self.seed, self.device)
        self.t_step = 0
        self.steps = 0
        self.predicter = Classifier(state_size, action_size, self.seed).to(self.device)
        self.optimizer_pre = optim.Adam(self.predicter.parameters(), lr=self.lr_pre)
        pathname = "lr_{}_batch_size_{}_fc1_{}_fc2_{}_fc3_{}_seed_{}".format(
            self.lr, self.batch_size, self.fc1, self.fc2, self.fc3, self.seed)
        pathname += "_clip_{}".format(config["clip"])
        pathname += "_tau_{}".format(config["tau"])
        now = datetime.now()
        dt_string = now.strftime("%d_%m_%Y_%H:%M:%S")
        pathname += dt_string
        tensorboard_name = str(config["locexp"]) + '/runs/' + pathname
        self.writer = SummaryWriter(tensorboard_name)
        print("summary writer ", tensorboard_name)
        self.average_prediction = deque(maxlen=100)
        self.average_same_action = deque(maxlen=100)
        self.all_actions = []
        for a in range(self.action_size):
            action = torch.Tensor(1) * 0 + a
            self.all_actions.append(action.to(self.device))

    def learn(self, memory):
        logging.debug("--------------------------New episode-----------------------------------------------")
        states, next_states, actions, dones = memory.expert_policy(self.batch_size)
        self.steps += 1
        self.state_action_frq(states, actions)
        self.compute_shift_function(states, next_states, actions, dones)
        for a in range(self.action_size):
            action = torch.ones([self.batch_size, 1], device=self.device) * a
            self.compute_r_function(states, action)
        self.compute_q_function(states, next_states, actions, dones)
        self.soft_update(self.q_shift_local, self.q_shift_target, self.tau)
        self.soft_update(self.R_local, self.R_target, self.tau)
        self.soft_update(self.qnetwork_local, self.qnetwork_target, self.tau)
        return

    def learn_predicter(self, memory):
        """Train the classifier on a batch drawn from the expert memory."""
        states, next_states, actions, dones = memory.expert_policy(self.batch_size)
        self.state_action_frq(states, actions)

    def state_action_frq(self, states, action):
        """Train classifier to compute the state-action frequency."""
        self.predicter.train()
        output = self.predicter(states, train=True)
        output = output.squeeze(0)
        y = action.type(torch.long).squeeze(1)
        loss = nn.CrossEntropyLoss()(output, y)
        self.optimizer_pre.zero_grad()
        loss.backward()
        self.optimizer_pre.step()
        self.writer.add_scalar('Predict_loss', loss, self.steps)
        self.predicter.eval()

    def test_predicter(self, memory):
        """Report how often the classifier predicts the expert action."""
        self.predicter.eval()
        same_state_prediction = 0
        for i in range(memory.idx):
            states = memory.obses[i]
            actions = memory.actions[i]
            states = torch.as_tensor(states, device=self.device).unsqueeze(0)
            actions = torch.as_tensor(actions, device=self.device)
            output = self.predicter(states)
            output = F.softmax(output, dim=1)
            y = actions.type(torch.long).item()
            p = torch.argmax(output.data).item()
            if y == p:
                same_state_prediction += 1
        text = "Same prediction {} of {} ".format(same_state_prediction, memory.idx)
        print(text)
        self.predicter.train()

    def get_action_prob(self, states, actions):
        """Return clamped log-probabilities of the given actions, or None if too small."""
        actions = actions.type(torch.long)
        output = self.predicter(states)
        output = F.softmax(output, dim=1)
        action_prob = output.gather(1, actions)
        action_prob = action_prob + torch.finfo(torch.float32).eps
        # skip the update if a single action probability is too small
        if action_prob.shape[0] == 1:
            if action_prob.cpu().detach().numpy()[0][0] < 1e-4:
                return None
        action_prob = torch.log(action_prob)
        action_prob = torch.clamp(action_prob, min=self.clip, max=0)
        return action_prob

    def compute_shift_function(self, states, next_states, actions, dones):
        """Update value parameters using given batch of experience tuples.
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        actions = actions.type(torch.int64)
        with torch.no_grad():
            # Get max predicted Q values (for next states) from target model
            if self.double_dqn:
                qt = self.q_shift_local(next_states)
                max_q, max_actions = qt.max(1)
                Q_targets_next = self.qnetwork_target(next_states).gather(1, max_actions.unsqueeze(1))
            else:
                Q_targets_next = self.qnetwork_target(next_states).max(1)[0].unsqueeze(1)
            # Compute Q targets for current states
            Q_targets = self.gamma * Q_targets_next * dones
        # Get expected Q values from local model
        Q_expected = self.q_shift_local(states).gather(1, actions)
        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.optimizer_shift.zero_grad()
        loss.backward()
        self.writer.add_scalar('Shift_loss', loss, self.steps)
        self.optimizer_shift.step()

    def compute_r_function(self, states, actions, debug=False, log=False):
        """Update the reward network on a batch of states and actions."""
        actions = actions.type(torch.int64)
        size = states.shape[0]
        idx = 0
        all_zeros = []
        with torch.no_grad():
            y_shift = self.q_shift_target(states).gather(1, actions)
            log_a = self.get_action_prob(states, actions)
            index_list = index_None_value(log_a)
            if index_list is None:
                return
            y_r_part1 = log_a - y_shift
            y_r_part2 = torch.empty((size, 1), dtype=torch.float32).to(self.device)
            # sum over all actions other than the expert action
            for a, s in zip(actions, states):
                y_h = 0
                taken_actions = 0
                for b in self.all_actions:
                    b = b.type(torch.int64).unsqueeze(1)
                    n_b = self.get_action_prob(s.unsqueeze(0), b)
                    if torch.eq(a, b) or n_b is None:
                        logging.debug("best action {} ".format(a))
                        logging.debug("n_b action {} ".format(b))
                        logging.debug("n_b {} ".format(n_b))
                        continue
                    taken_actions += 1
                    r_hat = self.R_target(s.unsqueeze(0)).gather(1, b)
                    y_s = self.q_shift_target(s.unsqueeze(0)).gather(1, b)
                    n_b = n_b - y_s
                    y_h += (r_hat - n_b)
                    if debug:
                        print("action", b.item())
                        print("r_pre {:.3f}".format(r_hat.item()))
                        print("n_b {:.3f}".format(n_b.item()))
                if taken_actions == 0:
                    all_zeros.append(idx)
                else:
                    y_r_part2[idx] = (1. / taken_actions) * y_h
                idx += 1
            y_r = y_r_part1 + y_r_part2
            if len(index_list) > 0:
                print("none list", index_list)
        y = self.R_local(states).gather(1, actions)
        if log:
            text = "Action {:.2f} y target {:.2f} = n_a {:.2f} + {:.2f} and pre {:.2f}".format(
                actions.item(), y_r.item(), y_r_part1.item(), y_r_part2.item(), y.item())
            logging.debug(text)
        if debug:
            print("expert action ", actions.item())
            print("Correct action p {:.3f} ".format(y.item()))
            print("Correct action target {:.3f} ".format(y_r.item()))
            print("part1 correct action {:.2f} ".format(y_r_part1.item()))
            print("part2 incorrect action {:.2f} ".format(y_r_part2.item()))
        r_loss = F.mse_loss(y, y_r)
        # Minimize the loss
        self.optimizer_r.zero_grad()
        r_loss.backward()
        self.optimizer_r.step()
        self.writer.add_scalar('Reward_loss', r_loss, self.steps)
        if debug:
            print("after update r pre ", self.R_local(states).gather(1, actions).item())
            print("after update r target ", self.R_target(states).gather(1, actions).item())

    def compute_q_function(self, states, next_states, actions, dones, debug=False, log=False):
        """Update value parameters using given batch of experience tuples.

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        actions = actions.type(torch.int64)
        if debug:
            print("---------------q_update------------------")
            print("expert action ", actions.item())
            print("state ", states)
        with torch.no_grad():
            # Get max predicted Q values (for next states) from target model
            if self.double_dqn:
                qt = self.qnetwork_local(next_states)
                max_q, max_actions = qt.max(1)
                Q_targets_next = self.qnetwork_target(next_states).gather(1, max_actions.unsqueeze(1))
            else:
                Q_targets_next = self.qnetwork_target(next_states).max(1)[0].unsqueeze(1)
            # Compute Q targets for current states
            rewards = self.R_target(states).gather(1, actions)
            Q_targets = rewards + (self.gamma * Q_targets_next * dones)
        if debug:
            print("reward {}".format(rewards.item()))
            print("Q target next {}".format(Q_targets_next.item()))
            print("Q_target {}".format(Q_targets.item()))
        # Get expected Q values from local model
        Q_expected = self.qnetwork_local(states).gather(1, actions)
        if log:
            text = "Action {:.2f} q target {:.2f} = r_a {:.2f} + target {:.2f} and pre {:.2f}".format(
                actions.item(), Q_targets.item(), rewards.item(), Q_targets_next.item(), Q_expected.item())
            logging.debug(text)
        if debug:
            print("q for a {}".format(Q_expected))
        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        self.writer.add_scalar('Q_loss', loss, self.steps)
        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        if debug:
            print("q after update {}".format(self.qnetwork_local(states)))
            print("q loss {}".format(loss.item()))

    def dqn_train(self, n_episodes=2000, max_t=1000, eps_start=1.0, eps_end=0.01, eps_decay=0.995):
        env = gym.make('LunarLander-v2')
        scores = []                        # list containing scores from each episode
        scores_window = deque(maxlen=100)  # last 100 scores
        eps = eps_start
        for i_episode in range(1, n_episodes + 1):
            state = env.reset()
            score = 0
            for t in range(max_t):
                self.t_step += 1
                action = self.dqn_act(state, eps)
                next_state, reward, done, _ = env.step(action)
                self.step(state, action, reward, next_state, done)
                state = next_state
                score += reward
                if done:
                    self.test_q()
                    break
            scores_window.append(score)          # save most recent score
            scores.append(score)                 # save most recent score
            eps = max(eps_end, eps_decay * eps)  # decrease epsilon
            print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)), end="")
            if i_episode % 100 == 0:
                print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)))
            if np.mean(scores_window) >= 200.0:
                print('\nEnvironment solved in {:d} episodes!\tAverage Score: {:.2f}'.format(
                    i_episode - 100, np.mean(scores_window)))
                break

    def test_policy(self):
        env = gym.make('LunarLander-v2')
        logging.debug("new episode")
        average_score = []
        average_steps = []
        average_action = []
        for i in range(5):
            state = env.reset()
            score = 0
            same_action = 0
            logging.debug("new episode")
            for t in range(200):
                state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
                q_expert = self.expert_q(state)
                q_values = self.qnetwork_local(state)
                logging.debug("q expert a0: {:.2f} a1: {:.2f} a2: {:.2f} a3: {:.2f}".format(
                    q_expert.data[0][0], q_expert.data[0][1], q_expert.data[0][2], q_expert.data[0][3]))
                logging.debug("q values a0: {:.2f} a1: {:.2f} a2: {:.2f} a3: {:.2f}".format(
                    q_values.data[0][0], q_values.data[0][1], q_values.data[0][2], q_values.data[0][3]))
                action = torch.argmax(q_values).item()
                action_e = torch.argmax(q_expert).item()
                if action == action_e:
                    same_action += 1
                next_state, reward, done, _ = env.step(action)
                state = next_state
                score += reward
                if done:
                    average_score.append(score)
                    average_steps.append(t)
                    average_action.append(same_action)
                    break
        mean_steps = np.mean(average_steps)
        mean_score = np.mean(average_score)
        mean_action = np.mean(average_action)
        self.writer.add_scalar('Ave_episode_length', mean_steps, self.steps)
        self.writer.add_scalar('Ave_same_action', mean_action, self.steps)
        self.writer.add_scalar('Ave_score', mean_score, self.steps)

    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)
        # Learn every UPDATE_EVERY time steps.
        self.t_step = (self.t_step + 1) % 4
        if self.t_step == 0:
            # If enough samples are available in memory, get a random subset and learn
            if len(self.memory) > self.batch_size:
                experiences = self.memory.sample()
                self.update_q(experiences)

    def dqn_act(self, state, eps=0.):
        """Returns actions for given state as per current policy.

        Params
        ======
            state (array_like): current state
            eps (float): epsilon, for epsilon-greedy action selection
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()
        # Epsilon-greedy action selection
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def update_q(self, experiences, debug=False):
        """Update value parameters using given batch of experience tuples.
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences
        # Get max predicted Q values (for next states) from target model
        with torch.no_grad():
            Q_targets_next = self.qnetwork_target(next_states).max(1)[0].unsqueeze(1)
        # Compute Q targets for current states
        Q_targets = rewards + (self.gamma * Q_targets_next * (1 - dones))
        # Get expected Q values from local model
        Q_expected = self.qnetwork_local(states).gather(1, actions)
        if debug:
            print("----------------------")
            print("Q target", Q_targets)
            print("pre", Q_expected)
            print("all local", self.qnetwork_local(states))
        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target)

    def test_q(self):
        experiences = self.memory.test_sample()
        self.update_q(experiences, True)

    def test_q_value(self, memory):
        same_action = 0
        test_elements = memory.idx
        all_diff = 0
        self.predicter.eval()
        for i in range(test_elements):
            states = memory.obses[i]
            next_states = memory.next_obses[i]
            actions = memory.actions[i]
            dones = memory.not_dones[i]
            states = torch.as_tensor(states, device=self.device).unsqueeze(0)
            next_states = torch.as_tensor(next_states, device=self.device)
            actions = torch.as_tensor(actions, device=self.device)
            dones = torch.as_tensor(dones, device=self.device)
            with torch.no_grad():
                output = self.predicter(states)
                output = F.softmax(output, dim=1)
                q_values = self.qnetwork_local(states)
                expert_values = self.expert_q(states)
            print("q values ", q_values)
            print("ex values ", expert_values)
            best_action = torch.argmax(q_values).item()
            actions = actions.type(torch.int64)
            q_max = q_values.max(1)
            q = q_values[0][actions.item()].item()
            max_q = q_max[0].data.item()
            diff = max_q - q
            all_diff += diff
            if actions.item() != best_action:
                r = self.R_local(states)
                rt = self.R_target(states)
                qt = self.qnetwork_target(states)
                logging.debug("------------------false action --------------------------------")
                logging.debug("expert action {}".format(actions.item()))
                logging.debug("out predicter a0: {:.2f} a1: {:.2f} a2: {:.2f} a3: {:.2f}".format(
                    output.data[0][0], output.data[0][1], output.data[0][2], output.data[0][3]))
                logging.debug("q values a0: {:.2f} a1: {:.2f} a2: {:.2f} a3: {:.2f}".format(
                    q_values.data[0][0], q_values.data[0][1], q_values.data[0][2], q_values.data[0][3]))
                logging.debug("q target a0: {:.2f} a1: {:.2f} a2: {:.2f} a3: {:.2f}".format(
                    qt.data[0][0], qt.data[0][1], qt.data[0][2], qt.data[0][3]))
                logging.debug("rewards a0: {:.2f} a1: {:.2f} a2: {:.2f} a3: {:.2f}".format(
                    r.data[0][0], r.data[0][1], r.data[0][2], r.data[0][3]))
                logging.debug("re target a0: {:.2f} a1: {:.2f} a2: {:.2f} a3: {:.2f}".format(
                    rt.data[0][0], rt.data[0][1], rt.data[0][2], rt.data[0][3]))
            if actions.item() == best_action:
                same_action += 1
                continue
            # mismatch between greedy and expert action: dump diagnostics
            print("-------------------------------------------------------------------------------")
            print("state ", i)
            print("expert ", actions)
            print("q values", q_values.data)
            print("action prob predicter ", output.data)
            self.compute_r_function(states, actions.unsqueeze(0), True)
            self.compute_q_function(states, next_states.unsqueeze(0), actions.unsqueeze(0), dones, True)
        self.writer.add_scalar('diff', all_diff, self.steps)
        self.average_same_action.append(same_action)
        self.writer.add_scalar('Same_action', same_action, self.steps)
        print("Same actions {} of {}".format(same_action, test_elements))
        self.predicter.train()

    def soft_update(self, local_model, target_model, tau=1e-3):
        """Soft update model parameters: θ_target = τ*θ_local + (1 - τ)*θ_target.

        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)

    def save(self, filename):
        """Save the predictor and Q-network weights to disk."""
        mkdir("", filename)
        torch.save(self.predicter.state_dict(), filename + "_predicter.pth")
        torch.save(self.optimizer_pre.state_dict(), filename + "_predicter_optimizer.pth")
        torch.save(self.qnetwork_local.state_dict(), filename + "_q_net.pth")
        print("save models to {}".format(filename))

    def load(self, filename):
        self.predicter.load_state_dict(torch.load(filename + "_predicter.pth"))
        self.optimizer_pre.load_state_dict(torch.load(filename + "_predicter_optimizer.pth"))
        print("load models from {}".format(filename))
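# Quick standalone check of the soft (Polyak) update rule used above: after
# one update with interpolation factor tau, each target parameter moves a
# fraction tau toward the local parameter. This toy example is illustrative only.
import torch
import torch.nn as nn

local = nn.Linear(2, 2)
target = nn.Linear(2, 2)
tau = 0.1
with torch.no_grad():
    for t_p, l_p in zip(target.parameters(), local.parameters()):
        t_p.copy_(tau * l_p + (1.0 - tau) * t_p)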
    ncols=get_tty_columns(),
    dynamic_ncols=True,
    desc='[%s] Loss: %.5f, Accu: %.5f' % (stage_info[stage], 0.0, 0.0))

avg_loss = 0.0
avg_accu = 0.0
TP = 0.0
FP = 0.0
FN = 0.0
TN = 0.0
counter = 0

training = (stage == 'train')
if training:
    model.train()
else:
    model.eval()

non_zero = False
for data in progress_bar:
    labels = data['label']
    images = data['image']
    if torch.sum(images) > 0:
        if torch.cuda.device_count() > 0:
            labels = labels.cuda()
else:
    # load
    with open(config_path, 'r') as f:
        config = json.load(f)

    # split
    img_gen_config = config['img_gen_config']
    model_config = config['model_config']
    train_config = config['train_config']

    # warn
    warn_message = f'Using config from {config_path_rel}'
    warnings.warn(warn_message)

# build image data generators
train_generator, validation_generator, test_generator, class_names = build_set_generators(
    **img_gen_config)

# define model
cl = Classifier(img_gen_config=img_gen_config,
                model_config=model_config,
                input_shape=train_generator.x.shape[1:])
cl.class_names = class_names

# train
cl.train(train_generator, validation_generator, train_config)

# eval
cl.evaluate(test_generator)

# save
cl.save()
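# The config file loaded above is assumed to be JSON with the three top-level
# sections split out by the code. The inner keys shown here are hypothetical
# placeholders, not the project's actual schema:
#
# {
#     "img_gen_config": {"batch_size": 32, "target_size": [224, 224]},
#     "model_config":   {"backbone": "resnet50", "dropout": 0.5},
#     "train_config":   {"epochs": 20, "lr": 1e-3}
# }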
class Agent():
    def __init__(self, state_size, action_size, config):
        self.seed = config["seed"]
        torch.manual_seed(self.seed)
        np.random.seed(seed=self.seed)
        random.seed(self.seed)
        self.env = gym.make(config["env_name"])
        self.env.seed(self.seed)
        self.state_size = state_size
        self.action_size = action_size
        self.clip = config["clip"]
        self.device = 'cuda'
        print("Clip ", self.clip)
        print("cuda ", torch.cuda.is_available())
        self.double_dqn = config["DDQN"]
        print("Use double dqn", self.double_dqn)
        self.lr_pre = config["lr_pre"]
        self.batch_size = config["batch_size"]
        self.lr = config["lr"]
        self.tau = config["tau"]
        print("self tau", self.tau)
        self.gamma = 0.99
        self.target_entropy = -torch.prod(torch.Tensor([action_size]).to(self.device)).item()
        self.fc1 = config["fc1_units"]
        self.fc2 = config["fc2_units"]
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha = self.log_alpha.exp()
        self.alpha_optim = optim.Adam([self.log_alpha], lr=config["lr_alpha"])
        self.policy = SACActor(state_size, action_size, self.seed).to(self.device)
        self.policy_optim = optim.Adam(self.policy.parameters(), lr=config["lr_policy"])
        self.qnetwork_local = QNetwork(state_size, action_size, self.seed, self.fc1, self.fc2).to(self.device)
        self.qnetwork_target = QNetwork(state_size, action_size, self.seed, self.fc1, self.fc2).to(self.device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=self.lr)
        self.soft_update(self.qnetwork_local, self.qnetwork_target, 1)
        self.q_shift_local = SQNetwork(state_size, action_size, self.seed, self.fc1, self.fc2).to(self.device)
        self.q_shift_target = SQNetwork(state_size, action_size, self.seed, self.fc1, self.fc2).to(self.device)
        self.optimizer_shift = optim.Adam(self.q_shift_local.parameters(), lr=self.lr)
        self.soft_update(self.q_shift_local, self.q_shift_target, 1)
        self.R_local = SQNetwork(state_size, action_size, self.seed, self.fc1, self.fc2).to(self.device)
        self.R_target = SQNetwork(state_size, action_size, self.seed, self.fc1, self.fc2).to(self.device)
        self.optimizer_r = optim.Adam(self.R_local.parameters(), lr=self.lr)
        self.soft_update(self.R_local, self.R_target, 1)
        self.steps = 0
        self.predicter = Classifier(state_size, action_size, self.seed, 256, 256).to(self.device)
        self.optimizer_pre = optim.Adam(self.predicter.parameters(), lr=self.lr_pre)
        pathname = "lr_{}_batch_size_{}_fc1_{}_fc2_{}_seed_{}".format(
            self.lr, self.batch_size, self.fc1, self.fc2, self.seed)
        pathname += "_clip_{}".format(config["clip"])
        pathname += "_tau_{}".format(config["tau"])
        now = datetime.now()
        dt_string = now.strftime("%d_%m_%Y_%H:%M:%S")
        pathname += dt_string
        tensorboard_name = str(config["locexp"]) + '/runs/' + pathname
        self.vid_path = str(config["locexp"]) + '/vid'
        self.writer = SummaryWriter(tensorboard_name)
        print("summary writer ", tensorboard_name)
        self.average_prediction = deque(maxlen=100)
        self.average_same_action = deque(maxlen=100)
        self.all_actions = []
        for a in range(self.action_size):
            action = torch.Tensor(1) * 0 + a
            self.all_actions.append(action.to(self.device))

    def learn(self, memory_ex, memory_all):
        self.steps += 1
        logging.debug("--------------------------New update-----------------------------------------------")
        states, next_states, actions, dones = memory_ex.expert_policy(self.batch_size)
        self.state_action_frq(states, actions)
        states, next_states, actions, dones = memory_all.expert_policy(self.batch_size)
        self.compute_shift_function(states, next_states, actions, dones)
        self.compute_r_function(states, actions)
        self.compute_q_function(states, next_states, actions, dones)
        self.soft_update(self.R_local, self.R_target, self.tau)
        self.soft_update(self.q_shift_local, self.q_shift_target, self.tau)
        self.soft_update(self.qnetwork_local, self.qnetwork_target, self.tau)
        return

    def compute_q_function(self, states, next_states, actions, dones):
        """Update value parameters using given batch of experience tuples.

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        qf1, qf2 = self.qnetwork_local(states)
        q_value1 = qf1.gather(1, actions)
        q_value2 = qf2.gather(1, actions)
        with torch.no_grad():
            q1_target, q2_target = self.qnetwork_target(next_states)
            min_q_target = torch.min(q1_target, q2_target)
            next_action_prob, next_action_log_prob = self.policy(next_states)
            next_q_target = (next_action_prob * (min_q_target - self.alpha * next_action_log_prob)).sum(dim=1, keepdim=True)
            rewards = self.R_target(states).gather(1, actions).squeeze(0)
            Q_targets = rewards + ((1 - dones) * self.gamma * next_q_target)
        loss = F.mse_loss(q_value1, Q_targets.detach()) + F.mse_loss(q_value2, Q_targets.detach())
        self.writer.add_scalar('loss/q_loss', loss, self.steps)
        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # -------------------------- update policy --------------------------
        action_prob, log_action_prob = self.policy(states)
        with torch.no_grad():
            q_pi1, q_pi2 = self.qnetwork_local(states)
            min_q_values = torch.min(q_pi1, q_pi2)
        policy_loss = (action_prob * ((self.alpha * log_action_prob) - min_q_values)).sum(dim=1).mean()
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()
        self.writer.add_scalar('loss/policy', policy_loss, self.steps)

        # -------------------------- update alpha ---------------------------
        alpha_loss = (action_prob.detach() * (-self.log_alpha * (log_action_prob + self.target_entropy).detach())).sum(dim=1).mean()
        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        self.writer.add_scalar('loss/alpha', alpha_loss, self.steps)
        self.alpha = self.log_alpha.exp()

    def compute_shift_function(self, states, next_states, actions, dones):
        """Update value parameters using given batch of experience tuples.
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        actions = actions.type(torch.int64)
        with torch.no_grad():
            # Get max predicted Q values (for next states) from the target model,
            # using the clipped double-Q trick
            qt1, qt2 = self.qnetwork_local(next_states)
            q_min = torch.min(qt1, qt2)
            max_q, max_actions = q_min.max(1)
            Q_targets_next1, Q_targets_next2 = self.qnetwork_target(next_states)
            Q_targets_next = torch.min(Q_targets_next1, Q_targets_next2)
            Q_targets_next = Q_targets_next.gather(1, max_actions.type(torch.int64).unsqueeze(1))
            # Compute Q targets for current states
            Q_targets = self.gamma * Q_targets_next * dones
        # Get expected Q values from local model
        Q_expected = self.q_shift_local(states).gather(1, actions)
        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets.detach())
        # Minimize the loss
        self.optimizer_shift.zero_grad()
        loss.backward()
        self.writer.add_scalar('Shift_loss', loss, self.steps)
        self.optimizer_shift.step()

    def compute_r_function(self, states, actions, debug=False, log=False):
        actions = actions.type(torch.int64)
        size = states.shape[0]
        idx = 0
        all_zeros = [1 for i in range(actions.shape[0])]
        zeros = False
        y_shift = self.q_shift_target(states).gather(1, actions).detach()
        log_a = self.get_action_prob(states, actions).detach()
        y_r_part1 = log_a - y_shift
        y_r_part2 = torch.empty((size, 1), dtype=torch.float32).to(self.device)
        # sum over all actions other than the expert action
        for a, s in zip(actions, states):
            y_h = 0
            taken_actions = 0
            for b in self.all_actions:
                b = b.type(torch.int64).unsqueeze(1)
                n_b = self.get_action_prob(s.unsqueeze(0), b)
                if torch.eq(a, b) or n_b is None:
                    continue
                taken_actions += 1
                y_s = self.q_shift_target(s.unsqueeze(0)).detach().gather(1, b).item()
                n_b = n_b.data.item() - y_s
                r_hat = self.R_target(s.unsqueeze(0)).gather(1, b).item()
                y_h += (r_hat - n_b)
                if log:
                    text = "a {} r_hat {:.2f} - n_b {:.2f} | sh {:.2f} ".format(b.item(), r_hat, n_b, y_s)
                    logging.debug(text)
            if taken_actions == 0:
                all_zeros[idx] = 0
                zeros = True
                y_r_part2[idx] = 0.0
            else:
                y_r_part2[idx] = (1. / taken_actions) * y_h
            idx += 1
        y_r = y_r_part1 + y_r_part2
        # if there were tuples with no valid update, mask them out of the batch
        if zeros:
            mask = torch.BoolTensor(all_zeros)
            states = states[mask]
            actions = actions[mask]
            y_r = y_r[mask]
        y = self.R_local(states).gather(1, actions)
        if log:
            text = "Action {:.2f} r target {:.2f} = n_a {:.2f} + n_b {:.2f} y {:.2f}".format(
                actions[0].item(), y_r[0].item(), y_r_part1[0].item(), y_r_part2[0].item(), y[0].item())
            logging.debug(text)
        r_loss = F.mse_loss(y, y_r.detach())
        # Minimize the loss
        self.optimizer_r.zero_grad()
        r_loss.backward()
        self.optimizer_r.step()
        self.writer.add_scalar('Reward_loss', r_loss, self.steps)

    def get_action_prob(self, states, actions):
        """Return clamped log-probabilities of the given actions, or None if too small."""
        actions = actions.type(torch.long)
        output = self.predicter(states)
        output = F.softmax(output, dim=1)
        action_prob = output.gather(1, actions)
        action_prob = action_prob + torch.finfo(torch.float32).eps
        # skip the update if a single action probability is too small
        if action_prob.shape[0] == 1:
            if action_prob.cpu().detach().numpy()[0][0] < 1e-4:
                return None
        action_prob = torch.log(action_prob)
        action_prob = torch.clamp(action_prob, min=self.clip, max=0)
        return action_prob

    def state_action_frq(self, states, action):
        """Train classifier to compute the state-action frequency."""
        self.predicter.train()
        output = self.predicter(states, train=True)
        output = output.squeeze(0)
        y = action.type(torch.long).squeeze(1)
        loss = nn.CrossEntropyLoss()(output, y)
        self.optimizer_pre.zero_grad()
        loss.backward()
        self.optimizer_pre.step()
        self.writer.add_scalar('Predict_loss', loss, self.steps)
        self.predicter.eval()

    def test_predicter(self, memory):
        """Report how often the classifier predicts the expert action."""
        self.predicter.eval()
        same_state_prediction = 0
        for i in range(memory.idx):
            states = memory.obses[i]
            actions = memory.actions[i]
            states = torch.as_tensor(states, device=self.device).unsqueeze(0)
            actions = torch.as_tensor(actions, device=self.device)
            output = self.predicter(states)
            output = F.softmax(output, dim=1)
            y = actions.type(torch.long).item()
            p = torch.argmax(output.data).item()
            if y == p:
                same_state_prediction += 1
        text = "Same prediction {} of {} ".format(same_state_prediction, memory.idx)
        print(text)
        logging.debug(text)

    def soft_update(self, local_model, target_model, tau=1e-3):
        """Soft update model parameters:
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)

    def load(self, filename):
        self.predicter.load_state_dict(torch.load(filename + "_predicter.pth"))
        self.optimizer_pre.load_state_dict(torch.load(filename + "_predicter_optimizer.pth"))
        self.R_local.load_state_dict(torch.load(filename + "_r_net.pth"))
        self.qnetwork_local.load_state_dict(torch.load(filename + "_q_net.pth"))
        print("load models from {}".format(filename))

    def save(self, filename):
        """Save all networks and optimizers to disk."""
        mkdir("", filename)
        torch.save(self.predicter.state_dict(), filename + "_predicter.pth")
        torch.save(self.optimizer_pre.state_dict(), filename + "_predicter_optimizer.pth")
        torch.save(self.qnetwork_local.state_dict(), filename + "_q_net.pth")
        torch.save(self.optimizer.state_dict(), filename + "_q_net_optimizer.pth")
        torch.save(self.R_local.state_dict(), filename + "_r_net.pth")
        torch.save(self.q_shift_local.state_dict(), filename + "_q_shift_net.pth")
        print("save models to {}".format(filename))

    def test_q_value(self, memory):
        test_elements = memory.idx
        used_elements_r = 0
        used_elements_q = 0
        r_error = 0
        q_error = 0
        for i in range(test_elements):
            states = memory.obses[i]
            actions = memory.actions[i]
            states = torch.as_tensor(states, device=self.device).unsqueeze(0)
            actions = torch.as_tensor(actions, device=self.device)
            one_hot = torch.zeros(self.action_size)
            one_hot[actions.item()] = 1
            with torch.no_grad():
                r_values = self.R_local(states)
                q_values1, q_values2 = self.qnetwork_local(states)
                q_values = torch.min(q_values1, q_values2)
                soft_r = F.softmax(r_values, dim=1).to("cpu")
                soft_q = F.softmax(q_values, dim=1).to("cpu")
            actions = actions.type(torch.int64)
            # F.kl_div expects log-probabilities first and probabilities second
            kl_q = F.kl_div(soft_q.log(), one_hot, None, None, 'sum')
            kl_r = F.kl_div(soft_r.log(), one_hot, None, None, 'sum')
            if kl_r != float("inf"):
                r_error += kl_r
                used_elements_r += 1
            if kl_q != float("inf"):
                q_error += kl_q
                used_elements_q += 1
        average_q_kl = q_error / used_elements_q
        average_r_kl = r_error / used_elements_r
        text = "Kl div of Reward {} of {} elements".format(average_r_kl, used_elements_r)
        print(text)
        text = "Kl div of Q_values {} of {} elements".format(average_q_kl, used_elements_q)
        print(text)
        self.writer.add_scalar('KL_reward', average_r_kl, self.steps)
        self.writer.add_scalar('KL_q_values', average_q_kl, self.steps)

    def act(self, state):
        with torch.no_grad():
            state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
            action_prob, _ = self.policy(state)
            action = torch.argmax(action_prob)
            action = action.cpu().numpy()
        return action

    def eval_policy(self, record=False, eval_episodes=4):
        if record:
            env = wrappers.Monitor(self.env,
                                   str(self.vid_path) + "/{}".format(self.steps),
                                   video_callable=lambda episode_id: True,
                                   force=True)
        else:
            env = self.env
        scores_window = deque(maxlen=100)
        for i_episode in range(eval_episodes):
            episode_reward = 0
            state = env.reset()
            while True:
                action = self.act(state)
                state, reward, done, _ = env.step(action)
                episode_reward += reward
                if done:
                    break
            scores_window.append(episode_reward)
        if record:
            return
        average_reward = np.mean(scores_window)
        print("Eval Episode {} average Reward {} ".format(eval_episodes, average_reward))
        self.writer.add_scalar('Eval_reward', average_reward, self.steps)
    test_loss /= len(dataset_loader)
    return {
        'epoch': every_epoch,
        'average_loss': test_loss,
        'correct': corrcet,
        'total': len(dataset_loader.dataset),
        'accuracy': 100. * float(corrcet) / len(dataset_loader.dataset)
    }


if __name__ == '__main__':
    training_sta = []
    test_s_sta = []
    test_t_sta = []
    for epoch in range(total_epochs):
        feature_extrator.train()
        class_classifier.train()
        start_steps = epoch * len(source_loader)
        for index, (source, target) in enumerate(zip(source_loader, target_loader)):
            p = float(index + start_steps) / total_steps
            res = train(feature_extrator, class_classifier, source, target,
                        optimizer, index + start_steps)
            training_sta.append(res)
        test_source = test(feature_extrator, class_classifier, s_test_loader, epoch)
        test_target = test(feature_extrator, class_classifier, t_test_loader, epoch)
        test_s_sta.append(test_source)
        test_t_sta.append(test_target)
        print('###Test Source: Epoch: {}, avg_loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
            epoch + 1, test_source['average_loss'], test_source['correct'],
def main():
    args = parser.parse_args()

    # model
    model = Classifier(args.channels)
    optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9,
                          weight_decay=0.0001, nesterov=True)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epoch)
    if args.gpu is not None:
        model.cuda(args.gpu)

    # dataset
    raw_loader = torch.utils.data.DataLoader(
        Dataset(os.path.join(DATA_DIR, 'raw')),
        args.batch // 2, shuffle=True, drop_last=True)
    noised_loader = torch.utils.data.DataLoader(
        Dataset(os.path.join(DATA_DIR, 'noised_tgt')),
        args.batch // 2, shuffle=True, drop_last=True)

    # train
    for epoch in range(args.epoch):
        loss = 0
        accuracy = 0
        count = 0
        for x0, x1 in zip(noised_loader, raw_loader):
            if args.gpu is not None:
                x0 = x0.cuda(args.gpu)
                x1 = x1.cuda(args.gpu)

            # train: label noised samples as class 0, raw samples as class 1
            model.train()
            x = torch.cat((x0, x1), dim=0)
            t = torch.zeros((x.shape[0], 2), device=x.device).float()
            t[:x0.shape[0], 0] = 1
            t[x0.shape[0]:, 1] = 1
            x, t = mixup(x, t)
            y = model(x)
            e = (-1 * nn.functional.log_softmax(y, dim=1) * t).sum(dim=1).mean()
            optimizer.zero_grad()
            e.backward()
            optimizer.step()

            # validate
            model.eval()
            with torch.no_grad():
                y0 = (model(x0).max(dim=1)[1] == 0).float()
                y1 = (model(x1).max(dim=1)[1] == 1).float()
                a = torch.cat((y0, y1), dim=0).mean()

            loss += float(e) * len(x)
            accuracy += float(a) * len(x)
            count += len(x)

        print('[{}] lr={:.7f}, loss={:.4f}, accuracy={:.4f}'.format(
            epoch, float(optimizer.param_groups[0]['lr']),
            loss / count, accuracy / count), flush=True)
        scheduler.step()

        snapshot = {'channels': args.channels, 'model': model.state_dict()}
        torch.save(snapshot, '{}.tmp'.format(args.file))
        os.rename('{}.tmp'.format(args.file), args.file)
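# mixup(x, t) is called in main() above but not defined in this snippet. A
# minimal sketch of standard input mixup under that assumed signature (inputs
# and soft labels mixed with a Beta-distributed coefficient):
import numpy as np
import torch

def mixup(x, t, alpha=1.0):
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    mixed_t = lam * t + (1 - lam) * t[index]
    return mixed_x, mixed_t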
testloader = data.DataLoader(
    datasets.MNIST('../data', download=True, train=False,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=100, shuffle=True)


def accuracy():
    total_correct = 0.0
    total = 0.0
    with torch.no_grad():
        for images, targets in testloader:
            out = cla(images.cuda())
            preds = out.argmax(1)
            total_correct += (preds.cpu() == targets).float().sum()
            total += preds.shape[0]
    return total_correct / total


loss_fn = nn.CrossEntropyLoss()
for _ in range(5):
    cla.train(True)
    for images, targets in trainloader:
        opt.zero_grad()
        out = cla(images.cuda())
        loss = loss_fn(out, targets.cuda())
        loss.backward()
        opt.step()
    print("Loss: %f" % (loss,))
    cla.eval()
    print(accuracy())

torch.save(cla, "ckpts/classifier")
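# torch.save(cla, ...) above pickles the entire module object, which ties the
# checkpoint to the exact class definition and file layout. A commonly
# preferred alternative (a sketch with the same hypothetical path) saves only
# the weights and rebuilds the model before loading:
torch.save(cla.state_dict(), "ckpts/classifier_state.pth")
# later: model = TheClassifierClass(...)   # hypothetical constructor
#        model.load_state_dict(torch.load("ckpts/classifier_state.pth"))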
class MLCModel:
    """Multi-label classification model wrapper.

    Attributes:
        cfg (TYPE): parsed configuration, see cfg.json
        criterion (TYPE): loss function
        device (TYPE): cpu or gpu
        hparams (TYPE): hyper-parameters from the argument parser
        labels (TYPE): list of the diseases, see init_labels()
        model (TYPE): feature-extraction backbone with classifier, see cfg.json
        names (TYPE): list of filenames of the images drawn from the dataloader
        num_tasks (TYPE): 5 or 14, number of diseases
    """

    def __init__(self, hparams):
        """
        Args:
            hparams (TYPE): hyper-parameters from the argument parser
        """
        super(MLCModel, self).__init__()
        self.hparams = hparams
        self.device = torch.device("cuda:{}".format(hparams.gpus)
                                   if torch.cuda.is_available() else "cpu")
        with open(self.hparams.json_path, 'r') as f:
            self.cfg = edict(json.load(f))
        hparams_dict = vars(self.hparams)
        self.cfg['hparams'] = hparams_dict
        if self.hparams.verbose is True:
            print(json.dumps(self.cfg, indent=4))
        if self.cfg.criterion in ['bce', 'focal', 'sce', 'bce_v2', 'bfocal']:
            self.criterion = init_loss_func(self.cfg.criterion, device=self.device)
        elif self.cfg.criterion == 'class_balance':
            samples_per_cls = list(map(int, self.cfg.samples_per_cls.split(',')))
            self.criterion = init_loss_func(self.cfg.criterion,
                                            samples_per_cls=samples_per_cls,
                                            loss_type=self.cfg.loss_type)
        else:
            self.criterion = init_loss_func(self.cfg.criterion)
        self.labels = init_labels(name=self.hparams.data_name)
        if self.cfg.extract_fields is None:
            self.cfg.extract_fields = ','.join(
                [str(idx) for idx in range(len(self.labels))])
        else:
            assert isinstance(self.cfg.extract_fields, str), "extract_fields must be a string!"
        self.model = Classifier(self.cfg, self.hparams)
        self.state_dict = None
        # Load cross-model from another configuration
        if self.hparams.load is not None and len(self.hparams.load) > 0:
            if not os.path.exists(hparams.load):
                raise ValueError('{} does not exist!'.format(hparams.load))
            state_dict = load_state_dict(self.hparams.load, self.model, self.device)
            self.state_dict = state_dict
        # DataParallel model
        if torch.cuda.device_count() > 1 and self.hparams.gpus == 0:
            self.model = nn.DataParallel(self.model)
        self.model.to(device=self.device)
        self.num_tasks = list(map(int, self.cfg.extract_fields.split(',')))
        self.names = list()
        self.optimizer, self.scheduler = self.configure_optimizers()
        self.train_loader = self.train_dataloader()
        self.valid_loader = self.val_dataloader()
        self.test_loader = self.test_dataloader()

    def forward(self, x):
        """Run the underlying model on an image batch.

        Args:
            x (TYPE): image

        Returns:
            TYPE: predicted logits
        """
        return self.model(x)

    def train(self):
        epoch_start = 0
        summary_train = {
            'epoch': 0,
            'step': 0,
            'total_step': len(self.train_loader)
        }
        summary_dev = {'loss': float('inf'), 'score': 0.0}
        best_dict = {
            "score_dev_best": 0.0,
            "loss_dev_best": float('inf'),
            "score_top_k": [0.0],
            "loss_top_k": [0.0],
            "score_curr_idx": 0,
            "loss_curr_idx": 0
        }
        if self.state_dict is not None:
            summary_train = {
                'epoch': self.state_dict['epoch'],
                'step': self.state_dict['step'],
                'total_step': len(self.train_loader)
            }
            best_dict['score_dev_best'] = self.state_dict['score_dev_best']
            best_dict['loss_dev_best'] = self.state_dict['loss_dev_best']
            epoch_start = self.state_dict['epoch']
        for epoch in range(epoch_start, self.hparams.epochs):
            lr = self.create_scheduler(start_epoch=summary_train['epoch'])
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
            logging.info('Learning rate in epoch {}: {}'.format(
                epoch + 1, self.optimizer.param_groups[0]['lr']))
            print('Learning rate in epoch {}: {}'.format(
                epoch + 1, self.optimizer.param_groups[0]['lr']))
            summary_train, best_dict = self.training_step(
                summary_train, summary_dev, best_dict)
            self.validation_end(summary_dev, summary_train, best_dict)
        torch.save(
            {
                'epoch': summary_train['epoch'],
                'step': summary_train['step'],
                'score_dev_best': best_dict['score_dev_best'],
                'loss_dev_best': best_dict['loss_dev_best'],
                'state_dict': self.model.state_dict()
            },
            os.path.join(self.hparams.save_path,
                         '{}_model.pth'.format(summary_train['epoch'] - 1)))
        logging.info('Training finished, model saved')
        print('Training finished, model saved')

    def training_step(self, summary_train, summary_dev, best_dict):
        """Extract batches of datapoints and train on the predicted logits
        for one epoch.

        Args:
            summary_train: running training counters
            summary_dev: running validation summary
            best_dict: best-checkpoint bookkeeping

        Returns:
            TYPE: updated (summary_train, best_dict)
        """
        losses = AverageMeter()
        torch.set_grad_enabled(True)
        self.model.train()
        time_now = time.time()
        for i, (inputs, target, _) in enumerate(self.train_loader):
            if isinstance(inputs, tuple):
                inputs = tuple([
                    e.to(self.device) if type(e) == torch.Tensor else e
                    for e in inputs
                ])
            else:
                inputs = inputs.to(self.device)
            target = target.to(self.device)
            self.optimizer.zero_grad()
            if self.cfg.no_jsd:
                if self.cfg.n_crops:
                    bs, n_crops, c, h, w = inputs.size()
                    inputs = inputs.view(-1, c, h, w)
                    if len(self.hparams.mixtype) > 0:
                        if self.hparams.multi_cls:
                            target = target.view(target.size()[0], -1)
                            inputs, targets_a, targets_b, lam = self.mix_data(
                                inputs, target.repeat(1, n_crops).view(-1),
                                self.device, self.hparams.alpha)
                        else:
                            inputs, targets_a, targets_b, lam = self.mix_data(
                                inputs, target.repeat(1, n_crops).view(-1, len(self.num_tasks)),
                                self.device, self.hparams.alpha)
                    logits = self.forward(inputs)
                    if len(self.hparams.mixtype) > 0:
                        loss_func = self.mixup_criterion(targets_a, targets_b, lam)
                        loss = loss_func(self.criterion, logits)
                    else:
                        if self.hparams.multi_cls:
                            target = target.view(target.size()[0], -1)
                            loss = self.criterion(logits, target.repeat(1, n_crops).view(-1))
                        else:
                            loss = self.criterion(
                                logits, target.repeat(1, n_crops).view(-1, len(self.num_tasks)))
                else:
                    if len(self.hparams.mixtype) > 0:
                        inputs, targets_a, targets_b, lam = self.mix_data(
                            inputs, target, self.device, self.hparams.alpha)
                    logits = self.forward(inputs)
                    if len(self.hparams.mixtype) > 0:
                        loss_func = self.mixup_criterion(targets_a, targets_b, lam)
                        loss = loss_func(self.criterion, logits)
                    else:
                        loss = self.criterion(logits, target)
            else:
                images_all = torch.cat(inputs, 0)
                logits_all = self.forward(images_all)
                logits_clean, logits_aug1, logits_aug2 = torch.split(
                    logits_all, inputs[0].size(0))
                # Cross-entropy is only computed on clean images
                loss = F.cross_entropy(logits_clean, target)
                p_clean = F.softmax(logits_clean, dim=1)
                p_aug1 = F.softmax(logits_aug1, dim=1)
                p_aug2 = F.softmax(logits_aug2, dim=1)
                # Clamp mixture distribution to avoid exploding KL divergence
                p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
                loss += 12 * (F.kl_div(p_mixture, p_clean, reduction='batchmean') +
                              F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
                              F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.
            assert not np.isnan(loss.item()), 'Model diverged with loss = NaN'
            loss.backward()
            self.optimizer.step()
            summary_train['step'] += 1
            losses.update(loss.item(), target.size(0))
            if summary_train['step'] % self.hparams.log_every == 0:
                time_spent = time.time() - time_now
                time_now = time.time()
                logging.info('Train, '
                             'Epoch : {}, '
                             'Step : {}/{}, '
                             'Loss: {loss.val:.4f} ({loss.avg:.4f}), '
                             'Run Time : {runtime:.2f} sec'.format(
                                 summary_train['epoch'] + 1,
                                 summary_train['step'],
                                 summary_train['total_step'],
                                 loss=losses,
                                 runtime=time_spent))
                print('Train, '
                      'Epoch : {}, '
                      'Step : {}/{}, '
                      'Loss: {loss.val:.4f} ({loss.avg:.4f}), '
                      'Run Time : {runtime:.2f} sec'.format(
                          summary_train['epoch'] + 1,
                          summary_train['step'],
                          summary_train['total_step'],
                          loss=losses,
                          runtime=time_spent))
            if summary_train['step'] % self.hparams.test_every == 0:
                self.validation_end(summary_dev, summary_train, best_dict)
                self.model.train()
                torch.set_grad_enabled(True)
        summary_train['epoch'] += 1
        return summary_train, best_dict

    def validation_step(self, summary_dev):
        """Extract batches of datapoints and collect the predicted logits
        in the validation step.

        Args:
            summary_dev (TYPE): running validation summary

        Returns:
            TYPE: (summary_dev, predictions, targets)
        """
        losses = AverageMeter()
        torch.set_grad_enabled(False)
        self.model.eval()
        output_ = np.array([])
        target_ = np.array([])
        with torch.no_grad():
            for i, (inputs, target, _) in enumerate(self.valid_loader):
                target = target.to(self.device)
                if isinstance(inputs, tuple):
                    inputs = tuple([
                        e.to(self.device) if type(e) == torch.Tensor else e
                        for e in inputs
                    ])
                else:
                    inputs = inputs.to(self.device)
                logits = self.forward(inputs)
                loss = self.criterion(logits, target)
                losses.update(loss.item(), target.size(0))
                if self.hparams.multi_cls:
                    output = F.softmax(logits, dim=1)
                    _, output = torch.max(output, 1)
                else:
                    output = torch.sigmoid(logits)
                target = target.detach().to('cpu').numpy()
                target_ = np.concatenate((target_, target), axis=0) if len(target_) > 0 else target
                y_pred = output.detach().to('cpu').numpy()
                output_ = np.concatenate((output_, y_pred), axis=0) if len(output_) > 0 else y_pred
        summary_dev['loss'] = losses.avg
        return summary_dev, output_, target_

    def validation_end(self, summary_dev, summary_train, best_dict):
        """After validation ends, calculate the metrics and checkpoint the
        best models.

        Args:
            summary_dev (TYPE): running validation summary
            summary_train (TYPE): running training counters
            best_dict (TYPE): best-checkpoint bookkeeping
        """
        time_now = time.time()
        summary_dev, output_, target_ = self.validation_step(summary_dev)
        time_spent = time.time() - time_now
        if not self.hparams.auto_threshold:
            overall_pre, overall_rec, overall_fscore = get_metrics(
                copy.deepcopy(output_), target_, self.cfg.beta,
                self.cfg.threshold, self.cfg.metric_type)
        else:
            overall_pre, overall_rec, overall_fscore = self.find_best_fixed_threshold(
                output_, target_)
        resp = dict()
        if not self.hparams.multi_cls:
            for t in range(len(self.num_tasks)):
                y_pred = np.transpose(output_)[t]
                precision, recall, f_score = get_metrics(
                    copy.deepcopy(y_pred), np.transpose(target_)[t],
                    self.cfg.beta, self.cfg.threshold, 'binary')
                resp['precision_{}'.format(self.labels[self.num_tasks[t]])] = precision
                resp['recall_{}'.format(self.labels[self.num_tasks[t]])] = recall
                resp['f_score_{}'.format(self.labels[self.num_tasks[t]])] = f_score
        resp['overall_precision'] = overall_pre
        resp['overall_recall'] = overall_rec
        resp['overall_f_score'] = overall_fscore
        logging.info(
            'Dev, Step : {}/{}, Loss : {}, Fscore : {:.3f}, Precision : {:.3f}, '
            'Recall : {:.3f}, Run Time : {:.2f} sec'.format(
                summary_train['step'], summary_train['total_step'],
                summary_dev['loss'], resp['overall_f_score'],
                resp['overall_precision'], resp['overall_recall'], time_spent))
        print(
            'Dev, Step : {}/{}, Loss : {}, Fscore : {:.3f}, Precision : {:.3f}, '
            'Recall : {:.3f}, Run Time : {:.2f} sec'.format(
                summary_train['step'], summary_train['total_step'],
                summary_dev['loss'], resp['overall_f_score'],
                resp['overall_precision'], resp['overall_recall'], time_spent))
        save_best = False
        mean_score = resp['overall_f_score']
        if mean_score > min(best_dict['score_top_k']):
            self.update_top_k(mean_score, best_dict, 'score')
            if self.hparams.metric == 'score':
                save_best = True
        mean_loss = summary_dev['loss']
        if mean_loss < max(best_dict['loss_top_k']):
            self.update_top_k(mean_loss, best_dict, 'loss')
            if self.hparams.metric == 'loss':
                save_best = True
        if save_best:
            torch.save(
                {
                    'epoch': summary_train['epoch'],
                    'step': summary_train['step'],
                    'score_dev_best': best_dict['score_dev_best'],
                    'loss_dev_best': best_dict['loss_dev_best'],
                    'state_dict': self.model.state_dict()
                },
                os.path.join(self.hparams.save_path,
                             'best{}.pth'.format(best_dict['score_curr_idx'])))
            logging.info('Best {}, Step : {}/{}, Loss : {}, Score : {:.3f}'.format(
                best_dict['score_curr_idx'], summary_train['step'],
                summary_train['total_step'], summary_dev['loss'],
                best_dict['score_dev_best']))
            print('Best {}, Step : {}/{}, Loss : {}, Score : {:.3f}'.format(
                best_dict['score_curr_idx'], summary_train['step'],
                summary_train['total_step'], summary_dev['loss'],
                best_dict['score_dev_best']))

    def find_best_fixed_threshold(self, output_, target_):
        score = list()
        thrs = np.arange(0, 1.0, 0.01)
        pre_rec = list()
        for thr in tqdm.tqdm(thrs):
            pre, rec, fscore = get_metrics(copy.deepcopy(output_),
                                           copy.deepcopy(target_),
                                           self.cfg.beta, thr,
                                           self.cfg.metric_type)
            score.append(fscore)
            pre_rec.append([pre, rec])
        score = np.array(score)
        pm = score.argmax()
        best_thr, best_score = thrs[pm], score[pm].item()
        best_pre, best_rec = pre_rec[pm]
        print('thr={} F2={} prec={} rec={}'.format(best_thr, best_score, best_pre, best_rec))
        return best_pre, best_rec, best_score

    def test_step(self):
        """Extract batches of datapoints and collect the predicted logits
        in the test step.

        Returns:
            TYPE: predictions (and targets when run on the validation split)
        """
        torch.set_grad_enabled(False)
        self.model.eval()
        output_ = np.array([])
        target_ = np.array([])
        with torch.no_grad():
            for i, batch in enumerate(self.test_loader):
                if self.hparams.infer == 'valid':
                    # Evaluate
                    inputs, target, names = batch
                    target = target.to(self.device)
                else:
                    # Test
                    inputs, names = batch
                if isinstance(inputs, tuple):
                    inputs = tuple([
                        e.to(self.device) if type(e) == torch.Tensor else e
                        for e in inputs
                    ])
                else:
                    inputs = inputs.to(self.device)
                self.names.extend(names)
                if self.cfg.n_crops:
                    bs, n_crops, c, h, w = inputs.size()
                    inputs = inputs.view(-1, c, h, w)
                logits = self.forward(inputs)
                if self.cfg.n_crops:
                    logits = logits.view(bs, n_crops, -1).mean(1)
                if self.hparams.multi_cls:
                    output = F.softmax(logits, dim=1)
                    output = output[:, 1]
                else:
                    output = torch.sigmoid(logits)
                if self.hparams.infer == 'valid':
                    target = target.detach().to('cpu').numpy()
                    target_ = np.concatenate((target_, target), axis=0) if len(target_) > 0 else target
                y_pred = output.detach().to('cpu').numpy()
                output_ = np.concatenate((output_, y_pred), axis=0) if len(output_) > 0 else y_pred
        if self.hparams.infer == 'valid':
            return output_, target_
        else:
            return output_

    def test(self):
        """After the test ends, calculate the metrics.

        Returns:
            TYPE: response dict with overall metrics (valid mode only)
        """
        # inference dataset
        if self.hparams.infer == 'valid':
            output_, target_ = self.test_step()
        else:
            output_ = self.test_step()
        resp = dict()
        to_csv = {'Images': self.names}
        for t in range(len(self.num_tasks)):
            if self.hparams.multi_cls:
                y_pred = np.reshape(output_, output_.shape[0])
            else:
                y_pred = np.transpose(output_)[t]
            to_csv[self.labels[self.num_tasks[t]]] = y_pred
        # Only save scores to a json file when in valid mode
        if self.hparams.infer == 'valid':
            overall_pre, overall_rec, overall_fscore = get_metrics(
                copy.deepcopy(output_), copy.deepcopy(target_),
                self.cfg.beta, self.cfg.threshold, self.cfg.metric_type)
            resp['overall_pre'] = overall_pre
            resp['overall_rec'] = overall_rec
            resp['overall_f_score'] = overall_fscore
            with open(os.path.join(os.path.dirname(self.hparams.load),
                                   'scores_{}.csv'.format(uuid.uuid4())), 'w') as f:
                json.dump(resp, f)
        # Save predictions to a csv file for computing metrics in off-line mode
        path_df = DataFrame(to_csv, columns=to_csv.keys())
        path_df.to_csv(os.path.join(os.path.dirname(self.hparams.load),
                                    'predictions_{}.csv'.format(uuid.uuid4())),
                       index=False)
        return resp

    def configure_optimizers(self):
        """Build the optimizer and learning-rate scheduler from the config.
        Must be implemented.

        Returns:
            TYPE: (optimizer, scheduler)
        """
        optimizer = create_optimizer(self.cfg, self.model.parameters())
        if self.cfg.lr_scheduler == 'step':
            scheduler = optim.lr_scheduler.StepLR(optimizer,
                                                  step_size=self.cfg.step_size,
                                                  gamma=self.cfg.lr_factor)
        elif self.cfg.lr_scheduler == 'cosin':
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                             T_max=200,
                                                             eta_min=1e-6)
        elif self.cfg.lr_scheduler == 'cosin_epoch':
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.cfg.tmax, eta_min=self.cfg.eta_min)
        elif self.cfg.lr_scheduler == 'onecycle':
            max_lr = [g["lr"] for g in optimizer.param_groups]
            scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer, max_lr=max_lr, epochs=self.hparams.epochs,
                steps_per_epoch=len(self.train_dataloader()))
            scheduler = {"scheduler": scheduler, "interval": "step"}
        else:
            raise ValueError('Does not support {} learning rate scheduler'.format(
                self.cfg.lr_scheduler))
        return optimizer, scheduler

    def train_dataloader(self):
        """Return the train dataloader, see dataflow/__init__.py."""
        ds_train = init_dataset(self.hparams.data_name,
                                cfg=self.cfg,
                                data_path=self.hparams.data_path,
                                mode='train')
        return DataLoader(dataset=ds_train,
                          batch_size=self.cfg.train_batch_size,
                          shuffle=True,
                          num_workers=self.hparams.num_workers,
                          pin_memory=True)

    def val_dataloader(self):
        """Return the validation dataloader, see dataflow/__init__.py."""
        ds_val = init_dataset(self.hparams.data_name,
                              cfg=self.cfg,
                              data_path=self.hparams.data_path,
                              mode='valid')
        return DataLoader(dataset=ds_val,
                          batch_size=self.cfg.dev_batch_size,
                          shuffle=False,
                          num_workers=self.hparams.num_workers,
                          pin_memory=True)

    def test_dataloader(self):
        """Return the test dataloader, see dataflow/__init__.py."""
        ds_test = init_dataset(self.hparams.data_name,
                               cfg=self.cfg,
                               data_path=self.hparams.data_path,
                               mode=self.hparams.infer)
        return DataLoader(dataset=ds_test,
                          batch_size=self.cfg.dev_batch_size,
                          shuffle=False,
                          num_workers=self.hparams.num_workers,
                          pin_memory=True)

    def mix_data(self, x, y, device, alpha=1.0):
        """Re-construct input images and labels with one of two regularization
        methods, Mixup or Cutmix.
    def mix_data(self, x, y, device, alpha=1.0):
        """Re-construct input images and labels with one of two
        regularization methods: Mixup or Cutmix.

        :param x: input images
        :param y: labels
        :param device: cpu or gpu device
        :param alpha: parameter for the beta distribution
        :return: mixed inputs, pairs of targets, and lambda
        """
        if alpha > 0.:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1.
        batch_size = x.size()[0]
        index = torch.randperm(batch_size).to(device)
        y_a, y_b = y, y[index]
        if self.hparams.mixtype == 'mixup':
            mixed_x = lam * x + (1 - lam) * x[index, :]
        elif self.hparams.mixtype == 'cutmix':
            bbx1, bby1, bbx2, bby2 = self.rand_bbox(x.size(), lam)
            x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
            mixed_x = x
            # Adjust lambda to the exact area ratio of the pasted patch.
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) /
                       (x.size()[-1] * x.size()[-2]))
        else:
            raise ValueError('Mixtype {} does not exist'.format(
                self.hparams.mixtype))
        return mixed_x, y_a, y_b, lam

    def create_scheduler(self, start_epoch):
        """Learning rate schedule with respect to epoch.

        lr: float, initial learning rate
        lr_factor: float, decreasing factor applied at each milestone
        start_epoch: int, the current epoch
        lr_epochs: comma-separated ints, decrease at every epoch listed
        return: float, scheduled learning rate.
        """
        count = 0
        for epoch in self.hparams.lr_epochs.split(','):
            if start_epoch >= int(epoch):
                count += 1
                continue
            break
        return self.cfg.lr * np.power(self.cfg.lr_factor, count)

    @staticmethod
    def mixup_criterion(y_a, y_b, lam):
        """Re-constructed loss function for the mixup/cutmix regularization.

        Args:
            y_a: original labels
            y_b: shuffled labels after random permutation
            lam: sample drawn from the beta distribution

        Returns:
            Combined loss function
        """
        return lambda criterion, pred: lam * criterion(pred, y_a) + (
            1 - lam) * criterion(pred, y_b)

    @staticmethod
    def rand_bbox(size, lam):
        """Generate a random bounding box for the specified cut rate.

        Args:
            size: image size (batch, channels, width, height)
            lam: sample drawn from the beta distribution

        Returns:
            Coordinates of the top-left and bottom-right box vertices
        """
        W = size[2]
        H = size[3]
        cut_rat = np.sqrt(1. - lam)
        # np.int was removed in NumPy 1.24; the builtin int is equivalent here.
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)

        # uniform
        cx = np.random.randint(W)
        cy = np.random.randint(H)

        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)

        return bbx1, bby1, bbx2, bby2

    def update_top_k(self, mean, best_dict, metric):
        """Insert ``mean`` into the top-k list for ``metric`` ('score' is
        max-is-best, 'loss' is min-is-best) and record which slot it took."""
        metric_dev_best = '{}_dev_best'.format(metric)
        metric_top_k = '{}_top_k'.format(metric)
        metric_curr_idx = '{}_curr_idx'.format(metric)
        if metric == 'loss':
            if mean < best_dict[metric_dev_best]:
                best_dict[metric_dev_best] = mean
        else:
            if mean > best_dict[metric_dev_best]:
                best_dict[metric_dev_best] = mean

        if len(best_dict[metric_top_k]) >= self.hparams.save_top_k:
            # Replace the worst entry currently held in the top-k list.
            if metric == 'loss':
                min_idx = best_dict[metric_top_k].index(
                    max(best_dict[metric_top_k]))
            else:
                min_idx = best_dict[metric_top_k].index(
                    min(best_dict[metric_top_k]))
            curr_idx = min_idx
            best_dict[metric_top_k][min_idx] = mean
        else:
            curr_idx = len(best_dict[metric_top_k])
            best_dict[metric_top_k].append(mean)
        best_dict[metric_curr_idx] = curr_idx
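# A minimal illustrative sketch (editorial, not from the original source) of
# how mix_data and mixup_criterion are meant to combine in a training step;
# the model, criterion and batch below are stand-ins.
def _demo_mixup_step():
    model = torch.nn.Sequential(torch.nn.Flatten(),
                                torch.nn.Linear(3 * 8 * 8, 10))
    criterion = torch.nn.CrossEntropyLoss()
    x = torch.randn(4, 3, 8, 8)
    y = torch.randint(0, 10, (4,))

    lam = 0.7                          # in mix_data this comes from np.random.beta
    index = torch.randperm(x.size(0))
    mixed_x, y_a, y_b = lam * x + (1 - lam) * x[index], y, y[index]

    # Same shape as the lambda returned by mixup_criterion above.
    loss_fn = lambda crit, pred: lam * crit(pred, y_a) + (1 - lam) * crit(pred, y_b)
    loss = loss_fn(criterion, model(mixed_x))
    loss.backward()
    return loss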
def train_sentiment(opts):
    # use_cuda was undefined in this scope; derive the device directly.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    glove_loader = GloveLoader(
        os.path.join(opts.data_dir, 'glove', opts.glove_emb_file))

    train_loader = DataLoader(
        RottenTomatoesReviewDataset(opts.data_dir, 'train', glove_loader, opts.maxlen),
        batch_size=opts.bsize, shuffle=True, num_workers=opts.nworkers)
    valid_loader = DataLoader(
        RottenTomatoesReviewDataset(opts.data_dir, 'val', glove_loader, opts.maxlen),
        batch_size=opts.bsize, shuffle=False, num_workers=opts.nworkers)

    model = Classifier(opts.hidden_size, opts.dropout_p, glove_loader, opts.enc_arch)
    if opts.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=opts.lr,
                                     weight_decay=opts.wd)
    else:
        raise NotImplementedError("Unknown optim type")
    criterion = nn.CrossEntropyLoss()

    start_n_iter = 0
    # for choosing the best model
    best_val_acc = 0.0

    model_path = os.path.join(opts.save_path, 'model_latest.net')
    if opts.resume and os.path.exists(model_path):
        # restore training from the latest checkpoint
        print('====> Resuming training from previous checkpoint')
        save_state = torch.load(model_path, map_location='cpu')
        model.load_state_dict(save_state['state_dict'])
        start_n_iter = save_state['n_iter']
        best_val_acc = save_state['best_val_acc']
        opts = save_state['opts']
        opts.start_epoch = save_state['epoch'] + 1

    model = model.to(device)

    # for logging
    logger = TensorboardXLogger(opts.start_epoch, opts.log_iter, opts.log_dir)
    logger.set(['acc', 'loss'])
    logger.n_iter = start_n_iter

    for epoch in range(opts.start_epoch, opts.epochs):
        model.train()
        logger.step()

        for batch_idx, data in enumerate(train_loader):
            acc, loss = run_iter(opts, data, model, criterion, device)

            # optimizer step
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), opts.max_norm)
            optimizer.step()

            logger.update(acc, loss)

        val_loss, val_acc, time_taken = evaluate(opts, model, valid_loader,
                                                 criterion, device)
        # log the validation losses
        logger.log_valid(time_taken, val_acc, val_loss)
        print('')

        # Save the model to disk: the same state goes to model_best.net when
        # validation accuracy improves, and to model_latest.net every epoch.
        is_best = val_acc >= best_val_acc
        best_val_acc = max(best_val_acc, val_acc)
        save_state = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'n_iter': logger.n_iter,
            'opts': opts,
            'val_acc': val_acc,
            'best_val_acc': best_val_acc
        }
        if is_best:
            torch.save(save_state, os.path.join(opts.save_path, 'model_best.net'))
        torch.save(save_state, os.path.join(opts.save_path, 'model_latest.net'))
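# run_iter and evaluate are used above but defined elsewhere; the two helpers
# below are a hypothetical sketch of their contracts (an (inputs, labels)
# batch layout and the module's torch/time imports are assumed), not the
# repository's actual implementations.
def _run_iter_sketch(opts, data, model, criterion, device):
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device)
    logits = model(inputs)
    loss = criterion(logits, labels)
    acc = (logits.argmax(dim=1) == labels).float().mean().item()
    return acc, loss

def _evaluate_sketch(opts, model, loader, criterion, device):
    model.eval()
    start, total_loss, total_acc = time.time(), 0.0, 0.0
    with torch.no_grad():
        for data in loader:
            acc, loss = _run_iter_sketch(opts, data, model, criterion, device)
            total_loss += loss.item()
            total_acc += acc
    return total_loss / len(loader), total_acc / len(loader), time.time() - start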
def train(seed=0,
          dataset='grid',
          samplers=(UniformDatasetSampler, UniformLatentSampler),
          latent_dim=2,
          model_dim=256,
          device='cuda',
          conditional=False,
          learning_rate=2e-4,
          betas=(0.5, 0.9),
          batch_size=256,
          iterations=400,
          n_critic=5,
          objective='gan',
          gp_lambda=10,
          output_dir='results',
          plot=False,
          spec_norm=True):
    # Encode every hyperparameter into the experiment directory name.
    experiment_name = [
        seed, dataset, samplers[0].__name__, samplers[1].__name__, latent_dim,
        model_dim, device, conditional, learning_rate, betas[0], betas[1],
        batch_size, iterations, n_critic, objective, gp_lambda, plot, spec_norm
    ]
    experiment_name = '_'.join([str(p) for p in experiment_name])

    results_dir = os.path.join(output_dir, experiment_name)
    network_dir = os.path.join(results_dir, 'networks')
    eval_log = os.path.join(results_dir, 'eval.log')
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(network_dir, exist_ok=True)
    eval_file = open(eval_log, 'w')

    if plot:
        samples_dir = os.path.join(results_dir, 'samples')
        os.makedirs(samples_dir, exist_ok=True)

    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    data, labels = load_data(dataset)
    data_dim, num_classes = data.shape[1], len(set(labels))

    data_sampler = samplers[0](
        torch.tensor(data).float(),
        torch.tensor(labels).long()) if conditional else samplers[0](
            torch.tensor(data).float())
    noise_sampler = samplers[1](
        latent_dim, labels) if conditional else samplers[1](latent_dim)

    if conditional:
        test_data, test_labels = load_data(dataset, split='test')
        test_dataset = TensorDataset(
            torch.tensor(test_data).to(device).float(),
            torch.tensor(test_labels).to(device).long())
        test_dataloader = DataLoader(test_dataset, batch_size=4096)

        # In the conditional case the generator consumes [z; one-hot(y)] and
        # the discriminator consumes [x; one-hot(y)].
        G = Generator(latent_dim + num_classes, model_dim,
                      data_dim).to(device).train()
        D = Discriminator(model_dim, data_dim + num_classes,
                          spec_norm=spec_norm).to(device).train()

        # Two identically initialized classifiers: one trained on real data,
        # one on generated data, to compare downstream accuracy.
        C_real = Classifier(model_dim, data_dim, num_classes).to(device).train()
        C_fake = Classifier(model_dim, data_dim, num_classes).to(device).train()
        C_fake.load_state_dict(deepcopy(C_real.state_dict()))
        C_real_optimizer = optim.Adam(C_real.parameters(), lr=2 * learning_rate)
        C_fake_optimizer = optim.Adam(C_fake.parameters(), lr=2 * learning_rate)
        C_crit = nn.CrossEntropyLoss()
    else:
        G = Generator(latent_dim, model_dim, data_dim).to(device).train()
        D = Discriminator(model_dim, data_dim,
                          spec_norm=spec_norm).to(device).train()

    D_optimizer = optim.Adam(D.parameters(), lr=learning_rate, betas=betas)
    G_optimizer = optim.Adam(G.parameters(), lr=learning_rate, betas=betas)

    # Pre-build the constant tensors each objective needs.
    if objective == 'gan':
        fake_target = torch.zeros(batch_size, 1).to(device)
        real_target = torch.ones(batch_size, 1).to(device)
    elif objective == 'wgan':
        grad_target = torch.ones(batch_size, 1).to(device)
    elif objective == 'hinge':
        bound = torch.zeros(batch_size, 1).to(device)
        sub = torch.ones(batch_size, 1).to(device)

    stats = {'D': [], 'G': [], 'C_it': [], 'C_real': [], 'C_fake': []}

    if plot:
        fixed_latent_batch = noise_sampler.get_batch(20000)
        sample_figure = plt.figure(num=0, figsize=(5, 5))
        loss_figure = plt.figure(num=1, figsize=(10, 5))
        if conditional:
            accuracy_figure = plt.figure(num=2, figsize=(10, 5))
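    # Illustrative nested helper (editorial, not in the original): the 'wgan'
    # branch of the loop below inlines exactly this gradient-penalty
    # computation; shown here as a closure over D and gp_lambda for clarity.
    def _gradient_penalty_sketch(real, fake):
        alpha_ = torch.rand(real.size(0), 1, device=real.device).expand_as(real)
        interp = (alpha_ * real + (1 - alpha_) * fake).requires_grad_(True)
        out = D(interp)
        grads = torch.autograd.grad(outputs=out, inputs=interp,
                                    grad_outputs=torch.ones_like(out),
                                    create_graph=True, retain_graph=True,
                                    only_inputs=True)[0]
        return gp_lambda * (grads.norm(2, dim=1) - 1).pow(2).mean()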
    for it in range(iterations + 1):
        # Train Discriminator
        data_batch = data_sampler.get_batch(batch_size)
        latent_batch = noise_sampler.get_batch(batch_size)

        if conditional:
            x_real, y_real = data_batch[0].to(device), data_batch[1].to(device)
            real_sample = torch.cat([x_real, y_real], dim=1)
            z_fake, y_fake = latent_batch[0].to(device), latent_batch[1].to(device)
            x_fake = G(torch.cat([z_fake, y_fake], dim=1)).detach()
            fake_sample = torch.cat([x_fake, y_fake], dim=1)
        else:
            x_real = data_batch.to(device)
            real_sample = x_real
            z_fake = latent_batch.to(device)
            x_fake = G(z_fake).detach()
            fake_sample = x_fake

        D.zero_grad()
        real_pred = D(real_sample)
        fake_pred = D(fake_sample)

        if is_recorded(data_sampler):
            data_sampler.record(real_pred.detach().cpu().numpy())

        if is_weighted(data_sampler):
            weights = torch.tensor(
                data_sampler.get_weights()).to(device).float().view(
                    real_pred.shape)
        else:
            weights = torch.ones_like(real_pred).to(device)

        if objective == 'gan':
            D_loss = F.binary_cross_entropy(fake_pred, fake_target).mean() + (
                weights * F.binary_cross_entropy(real_pred, real_target)).mean()
            stats['D'].append(D_loss.item())
        elif objective == 'wgan':
            alpha = torch.rand(batch_size, 1).expand(real_sample.size()).to(device)
            interpolate = (alpha * real_sample +
                           (1 - alpha) * fake_sample).requires_grad_(True)
            gradients = torch.autograd.grad(outputs=D(interpolate),
                                            inputs=interpolate,
                                            grad_outputs=grad_target,
                                            create_graph=True,
                                            retain_graph=True,
                                            only_inputs=True)[0]
            gradient_penalty = (gradients.norm(2, dim=1) -
                                1).pow(2).mean() * gp_lambda
            D_loss = fake_pred.mean() - (real_pred * weights).mean()
            stats['D'].append(-D_loss.item())
            D_loss += gradient_penalty
        elif objective == 'hinge':
            D_loss = -(torch.min(real_pred - sub, bound) * weights).mean() \
                     - torch.min(-fake_pred - sub, bound).mean()
            stats['D'].append(D_loss.item())

        D_loss.backward()
        D_optimizer.step()
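        # Note (editorial): with sub = 1 and bound = 0, the 'hinge' branch
        # above reduces to relu(1 - real_pred).mean() + relu(1 + fake_pred).mean(),
        # the standard geometric-GAN hinge loss; the weights tensor reweights
        # only the real-sample term.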
        # Train Generator once every n_critic discriminator updates
        if it % n_critic == 0:
            G.zero_grad()
            latent_batch = noise_sampler.get_batch(batch_size)
            if conditional:
                z_fake, y_fake = latent_batch[0].to(device), latent_batch[1].to(device)
                x_fake = G(torch.cat([z_fake, y_fake], dim=1))
                fake_pred = D(torch.cat([x_fake, y_fake], dim=1))
            else:
                z_fake = latent_batch.to(device)
                x_fake = G(z_fake)
                fake_pred = D(x_fake)

            if objective == 'gan':
                G_loss = F.binary_cross_entropy(fake_pred, real_target).mean()
                stats['G'].extend([G_loss.item()] * n_critic)
            elif objective == 'wgan':
                G_loss = -fake_pred.mean()
                stats['G'].extend([-G_loss.item()] * n_critic)
            elif objective == 'hinge':
                G_loss = -fake_pred.mean()
                stats['G'].extend([-G_loss.item()] * n_critic)

            G_loss.backward()
            G_optimizer.step()

            if conditional:
                # Train fake classifier on generated samples
                C_fake.train()
                C_fake.zero_grad()
                C_fake_loss = C_crit(C_fake(x_fake.detach()), y_fake.argmax(1))
                C_fake_loss.backward()
                C_fake_optimizer.step()

                # Train real classifier on real samples
                C_real.train()
                C_real.zero_grad()
                C_real_loss = C_crit(C_real(x_real), y_real.argmax(1))
                C_real_loss.backward()
                C_real_optimizer.step()

                if it % 5 == 0:
                    C_real.eval()
                    C_fake.eval()
                    real_correct, fake_correct, total = 0.0, 0.0, 0.0
                    for idx, (sample, label) in enumerate(test_dataloader):
                        real_correct += (
                            C_real(sample).argmax(1).view(-1) == label).sum()
                        fake_correct += (
                            C_fake(sample).argmax(1).view(-1) == label).sum()
                        total += sample.shape[0]
                    stats['C_it'].append(it)
                    stats['C_real'].append(real_correct.item() / total)
                    stats['C_fake'].append(fake_correct.item() / total)

            line = f"{it}\t{stats['D'][-1]:.3f}\t{stats['G'][-1]:.3f}"
            if conditional:
                line += f"\t{stats['C_real'][-1]*100:.3f}\t{stats['C_fake'][-1]*100:.3f}"
            # Write to the log file; the original passed eval_file as a
            # positional argument, which printed its repr instead.
            print(line, file=eval_file)

            if plot:
                # Push the fixed latent batch through G to visualize coverage.
                if conditional:
                    z_fake, y_fake = fixed_latent_batch[0].to(
                        device), fixed_latent_batch[1].to(device)
                    x_fake = G(torch.cat([z_fake, y_fake], dim=1))
                else:
                    z_fake = fixed_latent_batch.to(device)
                    x_fake = G(z_fake)
                generated = x_fake.detach().cpu().numpy()

                plt.figure(0)
                plt.clf()
                plt.scatter(generated[:, 0], generated[:, 1],
                            marker='.', color=(0, 1, 0, 0.01))
                plt.axis('equal')
                plt.xlim(-1, 1)
                plt.ylim(-1, 1)
                plt.savefig(os.path.join(samples_dir, f'{it}.png'))

                plt.figure(1)
                plt.clf()
                plt.plot(stats['G'], label='Generator')
                plt.plot(stats['D'], label='Discriminator')
                plt.legend()
                plt.savefig(os.path.join(results_dir, 'loss.png'))

                if conditional:
                    plt.figure(2)
                    plt.clf()
                    plt.plot(stats['C_it'], stats['C_real'], label='Real')
                    plt.plot(stats['C_it'], stats['C_fake'], label='Fake')
                    plt.legend()
                    plt.savefig(os.path.join(results_dir, 'accuracy.png'))

    save_model(G, os.path.join(network_dir, 'G_trained.pth'))
    save_model(D, os.path.join(network_dir, 'D_trained.pth'))
    save_stats(stats, os.path.join(results_dir, 'stats.pth'))
    if conditional:
        save_model(C_real, os.path.join(network_dir, 'C_real_trained.pth'))
        save_model(C_fake, os.path.join(network_dir, 'C_fake_trained.pth'))
    eval_file.close()
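# A hypothetical invocation (editorial sketch): an unconditional WGAN-GP run
# on the 'grid' toy dataset, writing loss curves and sample scatter plots
# under results/.
if __name__ == '__main__':
    train(seed=0, dataset='grid', objective='wgan', gp_lambda=10,
          iterations=400, plot=True)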
# ------------------------ start training ------------------------
logging.info("start training")
header = '  Time Epoch Iteration Progress    (%Epoch)   Loss   Dev/Loss     Accuracy  Dev/Accuracy'
dev_log_template = ' '.join(
    '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(','))
log_template = ' '.join(
    '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(','))
print(header)

iterations = 0
start = time.time()
for epoch in range(args.epochs):
    n_correct, n_total = 0, 0
    train_iter.init_epoch()
    for batch_idx, batch in enumerate(train_iter):
        model.train()
        opt.zero_grad()
        iterations += 1

        predict = model(batch)
        n_correct += (torch.max(predict, 1)[1].view(
            batch.label.size()) == batch.label).sum().item()
        n_total += batch.batch_size
        train_acc = 100. * n_correct / n_total

        loss = criterion(predict, batch.label)
        loss.backward()
        opt.step()

        if iterations % args.dev_every == 0:
            # Evaluate on the dev set
            model.eval()
            dev_iter.init_epoch()
            n_dev_correct, dev_loss = 0, 0
            with torch.no_grad():
                for dev_batch_idx, dev_batch in enumerate(dev_iter):
                    predict = model(dev_batch)