def dqn_learning(env,
                 q_func,
                 optimizer_spec,
                 exploration=LinearSchedule(1000000, 0.1),
                 max_steps=20e6,
                 replay_buffer_size=1000000,
                 batch_size=32,
                 sample_size=128,
                 gamma=0.99,
                 beta=0.05,
                 reg_scale=0.1,
                 use_restart=True,
                 learning_starts=50000,
                 learning_freq=4,
                 frame_history_len=4,
                 target_update_freq=2000,
                 save_path=None):
    """Run the Deep Q-learning algorithm with regularized Anderson acceleration.

    You can specify your own convnet using q_func. All schedules are w.r.t. the
    total number of steps taken in the environment.

    Parameters
    ----------
    env: gym.Env
        gym environment to train on.
    q_func: function
        Model to use for computing the Q function.
    optimizer_spec: OptimizerSpec
        Specifies the constructor and kwargs, as well as the learning rate
        schedule, for the optimizer.
    exploration: rl_algs.deepq.utils.schedules.Schedule
        Schedule for the probability of choosing a random action.
    max_steps: float
        Maximum number of environment steps.
    replay_buffer_size: int
        How many memories to store in the replay buffer.
    batch_size: int
        How many transitions to use in each gradient update.
    sample_size: int
        How many transitions to sample for the Anderson acceleration step.
    gamma: float
        Discount factor.
    beta: float
        Mixing coefficient between the target Q values and their Bellman backups.
    reg_scale: float
        Regularization scale for the regularized Anderson acceleration (RAA).
    use_restart: bool
        Whether to allow restarts in the Anderson acceleration.
    learning_starts: int
        After how many environment steps to start replaying experiences.
    learning_freq: int
        How many steps of environment to take between every experience replay.
    frame_history_len: int
        How many past frames to include as input to the model.
    target_update_freq: int
        How many experience-replay rounds (not steps!) to perform between each
        update to the target Q network.
    save_path: str or None
        Directory for logs, saved scalars, and model checkpoints.
    """
    assert type(env.observation_space) == gym.spaces.Box
    assert type(env.action_space) == gym.spaces.Discrete

    # Set up the logger.
    logger = Logger(save_path)

    ###############
    # BUILD MODEL #
    ###############
    if len(env.observation_space.shape) == 1:
        # We are running on low-dimensional observations (e.g. RAM).
        input_shape = env.observation_space.shape
        in_channels = input_shape[0]
    else:
        img_h, img_w, img_c = env.observation_space.shape
        input_shape = (img_h, img_w, frame_history_len * img_c)
        in_channels = input_shape[2]
    num_actions = env.action_space.n

    # Define the online network Q and the pool of target networks.
    Q = q_func(in_channels, num_actions).to(device)
    Q_targets = []
    MAX_NUM = 5
    for _ in range(MAX_NUM):
        Q_targets.append(q_func(in_channels, num_actions).to(device))

    # Initialize the regularized Anderson acceleration.
    anderson = RAA(MAX_NUM, use_restart, reg_scale)

    # Initialize the optimizer.
    optimizer = optimizer_spec.constructor(Q.parameters(), **optimizer_spec.kwargs)

    # Create the replay buffer.
    replay_buffer = ReplayBuffer(replay_buffer_size, frame_history_len)

    ###############
    #   RUN ENV   #
    ###############
    num_param_updates = 0
    mean_episode_reward = float('nan')
    best_mean_episode_reward = -float('inf')
    last_obs = env.reset()
    LOG_EVERY_N_STEPS = 10000
    SAVE_MODEL_EVERY_N_STEPS = 100000
    saved_scalars = []
    stop = False
    restart = True
    cur_num = 1
    clipped_error = torch.FloatTensor([0]).to(device)

    for t in itertools.count():
        # 1. Step the env and store the transition.
        # Store the last frame; the returned idx is used later.
        last_stored_frame_idx = replay_buffer.store_frame(last_obs)

        # Get observations to input to the Q network (need to append prev frames).
        observations = replay_buffer.encode_recent_observation()  # torch tensor

        # Before learning starts, choose actions randomly.
        if t < learning_starts:
            action = np.random.randint(num_actions)
        else:
            # Epsilon-greedy exploration.
            sample = random.random()
            threshold = exploration.value(t)
            if sample > threshold:
                obs = observations.unsqueeze(0) / 255.0
                with torch.no_grad():
                    q_value_all_actions = Q(obs)
                action = (q_value_all_actions.data.max(1)[1])[0]
            else:
                action = torch.IntTensor([[np.random.randint(num_actions)]])[0][0]

        obs, reward, done, info = env.step(action)

        # Clip the reward, as in the Nature DQN paper.
        reward = np.clip(reward, -1.0, 1.0)

        # Store the effect of the action.
        replay_buffer.store_effect(last_stored_frame_idx, action, reward, done)

        # Reset the env if we reached an episode boundary.
        if done:
            obs = env.reset()
        last_obs = obs

        # 2. Perform experience replay and train the network,
        # once the replay buffer contains enough samples.
        if (t > learning_starts and t % learning_freq == 0
                and replay_buffer.can_sample(sample_size)):

            # Sample a transition batch from replay memory.
            # done_mask = 1 if the next state is the end of an episode.
            obs_t, act_t, rew_t, obs_tp1, done_mask = replay_buffer.sample(sample_size)
            obs_t = obs_t / 255.0
            act_t = torch.LongTensor(act_t).to(device)
            rew_t = torch.FloatTensor(rew_t).to(device)
            obs_tp1 = obs_tp1 / 255.0

            # Get the Q values for current observations: Q(s, a; theta_i).
            q_values = Q(obs_t[:batch_size, :])
            q_s_a = q_values.gather(1, act_t[:batch_size].unsqueeze(1))
            q_s_a = q_s_a.squeeze()

            if restart:
                cur_num = 1
                restart = False
                # Get the Q values for the best actions in obs_tp1, based off
                # the frozen Q network: max_a' Q(s', a'; theta_frozen).
                q_tp1_values = Q_targets[-1](obs_tp1[:batch_size, :]).detach()
                q_s_a_prime, _ = q_tp1_values.max(1)
                # If the current state is the end of an episode, there is no next Q value.
                q_rhs = rew_t[:batch_size] + gamma * (1 - done_mask[:batch_size]) * q_s_a_prime
            else:
                cur_num += 1
                num = min(MAX_NUM, cur_num)
                cat_obs = torch.cat((obs_t, obs_tp1), 0)
                qs_target_t_aa, qs_target_tp1_aa = [], []
                for i in range(num, 0, -1):
                    q_target = Q_targets[-i](cat_obs).detach()
                    q_aa = q_target[:sample_size, :].gather(1, act_t.unsqueeze(1))
                    qs_target_t_aa.append(q_aa.t())
                    q_next_aa, _ = q_target[sample_size:, :].max(1)
                    qs_target_tp1_aa.append(q_next_aa.unsqueeze(0))
                qs_target_t_values = torch.cat(qs_target_t_aa, 0)
                qs_target_tp1_values = torch.cat(qs_target_tp1_aa, 0)
                F_qs_target_t = torch.cat(
                    [(rew_t + gamma * (1 - done_mask) * q).unsqueeze(0)
                     for q in qs_target_tp1_values], 0)
                alpha, restart = anderson.calculate(qs_target_t_values, F_qs_target_t)
                # Mix the target Q values with their Bellman backups, then
                # combine across target networks with the Anderson coefficients alpha.
                hybrid_qs_target_tp1 = beta * qs_target_t_values[:, :batch_size] + \
                    (1 - beta) * F_qs_target_t[:, :batch_size]
                q_rhs = (hybrid_qs_target_tp1.t().mm(alpha)).detach()
                q_rhs = q_rhs.squeeze(1)

            # Compute the Bellman error:
            # r + gamma * max_a' Q(s', a'; theta_frozen) - Q(s, a; theta_i).
            error = q_rhs - q_s_a
            # Clip the error and flip its sign, so that backward() receives the
            # gradient of the clipped loss w.r.t. q_s_a.
            clipped_error = -1.0 * error.clamp(-1, 1)

            # Backward pass.
            optimizer.zero_grad()
            q_s_a.backward(clipped_error.data)

            # Update.
            optimizer.step()
            num_param_updates += 1

            # Update the target Q network weights with the current Q network weights.
            if num_param_updates % target_update_freq == 0:
                Q_targets[0].load_state_dict(Q.state_dict())
                # Rotate the freshly updated network to the end of the pool.
                Q_targets.append(Q_targets.pop(0))

        # 3. Log progress.
        if t % SAVE_MODEL_EVERY_N_STEPS == 0:
            if save_path is not None:
                torch.save(Q.state_dict(), '%s/net.pth' % save_path)

        if t % LOG_EVERY_N_STEPS == 0:
            underlying_env = get_wrapper_by_name(env, "Monitor")
            internal_steps = underlying_env.get_total_steps()
            stop = (internal_steps >= max_steps)
            episode_rewards = underlying_env.get_episode_rewards()
            num_episode = len(episode_rewards)
            if num_episode > 0:
                mean_episode_reward = np.mean(episode_rewards[-100:])
                best_mean_episode_reward = max(best_mean_episode_reward,
                                               mean_episode_reward)

            saved_scalars.append([
                t, internal_steps, num_episode, mean_episode_reward,
                clipped_error.mean().data.cpu().numpy()
            ])
            if save_path is not None:
                np.save('%s/scalars.npy' % save_path, saved_scalars)

            print("---------------------------------")
            print("Wrapped - Atari (steps) %d-%d" % (t, internal_steps))
            print("episodes %d" % num_episode)
            print("mean episode reward %f" % mean_episode_reward)
            print("best mean episode reward %f" % best_mean_episode_reward)
            print("exploration %f" % exploration.value(t))
            sys.stdout.flush()

            # ============ TensorBoard logging ============ #
            info = {
                'num_episodes': len(episode_rewards),
                'exploration': exploration.value(t),
                'mean_episode_reward_last_100': mean_episode_reward
            }
            for tag, value in info.items():
                logger.scalar_summary(tag, value, t + 1)

        # 4. Check the stop criterion.
        if stop:
            break
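# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original training code): a
# minimal way to launch dqn_learning, assuming OptimizerSpec is the usual
# namedtuple-style spec with `constructor` and `kwargs` fields, which is how
# optimizer_spec is consumed above. The caller supplies a Monitor-wrapped gym
# env with Box observations and Discrete actions, and a q_func callable with
# signature (in_channels, num_actions) -> nn.Module, as the function assumes.
def example_dqn_run(env, q_func):
    optimizer_spec = OptimizerSpec(
        constructor=torch.optim.RMSprop,
        kwargs=dict(lr=2.5e-4, alpha=0.95, eps=0.01),  # illustrative hyperparameters
    )
    dqn_learning(
        env=env,
        q_func=q_func,
        optimizer_spec=optimizer_spec,
        exploration=LinearSchedule(1000000, 0.1),
        max_steps=2e6,
        save_path='./runs/dqn_raa',  # hypothetical output directory
    )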
class NN(object):
    """
    This is a prototype for an NN wrapper in PyTorch. Please follow this
    coding style carefully.

    Args:
        model: PyTorch model. It must expose get_epochs(), get_optimizer(),
            get_criterion(), get_lr_scheduler(), get_tensorboard_path(), and
            get_logger_path().
        train_loader (torch.utils.data.DataLoader): DataLoader for the training dataset.
        val_loader (torch.utils.data.DataLoader): DataLoader for the validation dataset.
        test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.
        if_checkpoint_save (bool): save a checkpoint if True.
        penalty (dict): regularization penalty specification, e.g. {'type': 'l2', 'reg': ...}.
        print_result_epoch (bool): if True, print intermediate results during every epoch.
        print_metric_name (str): display name of the evaluation metric.
        metrics: evaluation metric function.
        score_function: optional function returning (output, scores, loss).
        create_save_file_name: function returning the base file name for checkpoints.
        target_reshape: optional function to reshape targets.
    """

    def __init__(self, model=None, train_loader=None, val_loader=None, test_loader=None,
                 if_checkpoint_save=True, penalty=None, print_result_epoch=False,
                 print_metric_name=None, metrics=None, score_function=None,
                 create_save_file_name=None, target_reshape=None, **kwargs):
        self.test_loader = test_loader
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        self.train_current_batch_data = {}
        self.valid_current_batch_data = {}
        self.if_checkpoint_save = if_checkpoint_save
        self.print_result_epoch = print_result_epoch
        self.penalty = penalty
        self.target_reshape = target_reshape
        self.metrics = metrics
        self.score_function = score_function
        self.create_save_file_name = create_save_file_name
        self.print_metric_name = print_metric_name
        self.epochs = self.model.get_epochs()
        self._optimizer = self.model.get_optimizer()
        self._criterion = self.model.get_criterion()
        self._lr_adjust = self.model.get_lr_scheduler()
        self._tensorboard_path = self.model.get_tensorboard_path()
        self._save_path = self.model.get_logger_path()
        self._logger = Logger(self._tensorboard_path)

        if not os.path.exists(os.path.join(self._save_path, 'train_save')):
            os.makedirs(os.path.join(self._save_path, 'train_save'))
        if not os.path.exists(os.path.join(self._save_path, 'test_save')):
            os.makedirs(os.path.join(self._save_path, 'test_save'))
        print(self._save_path, self.create_save_file_name())
        self.train_checkpoint_save = os.path.join(
            self._save_path, 'train_save', self.create_save_file_name() + '_ckpt.path.tar')
        self.train_model_save = os.path.join(
            self._save_path, 'train_save', self.create_save_file_name() + '_best.path.tar')
        self.test_checkpoint_save = os.path.join(
            self._save_path, 'test_save', self.create_save_file_name() + '_ckpt.path.tar')
        self.test_model_save = os.path.join(
            self._save_path, 'test_save', self.create_save_file_name() + '_best.path.tar')

        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise TypeError('should be a torch.optim.Optimizer type, instead of {}'
                            .format(type(self._optimizer)))

    def train(self):
        print('Start training process.')
        self.adjust_learning_rate()
        best_val_acc = 0
        best_test_acc = 0
        for epoch in range(self.epochs):
            start_time = time.time()
            self.train_epoch()
            if not torch.cuda.is_available() and self.test_loader is not None:
                self.multi_threading_val_test()
            elif torch.cuda.is_available() and self.test_loader is not None:
                self.evaluate()
                self.validate_epoch()
            else:
                self.validate_epoch()
            info = {'train_loss': self._train_loss.avg,
                    'train_{}'.format(self.print_metric_name): self._train_score.avg,
                    'val_loss': self._valid_loss.avg,
                    'val_{}'.format(self.print_metric_name): self._valid_score.avg}
            end_time = time.time()

            if self.if_checkpoint_save and self.test_loader is None:
                is_best = self._valid_score.avg > best_val_acc
                if is_best:
                    best_val_acc = self._valid_score.avg
                    self.set_best_valid_score(self._valid_score.avg)
                print('>>>>>>>>>>>>>>>>>>>>>>')
                print('epoch {} takes {:.2f}s to train'.format(epoch, end_time - start_time))
                print('Epoch: {} train loss: {}, train {}: {}, valid loss: {}, valid {}: {}'
                      .format(epoch, self._train_loss.avg, self.print_metric_name,
                              self._train_score.avg, self._valid_loss.avg,
                              self.print_metric_name, self._valid_score.avg))
                print('>>>>>>>>>>>>>>>>>>>>>>')
                self.save_checkpoint({'epoch': epoch + 1,
                                      'state_dict': self.model.state_dict(),
                                      'best_val_acc': best_val_acc,
                                      'optimizer': self._optimizer.state_dict()},
                                     is_best, self.train_checkpoint_save,
                                     self.train_model_save)
                info['best_val_{}'.format(self.print_metric_name)] = self._best_valid_score
            elif self.if_checkpoint_save and self.test_loader is not None:
                is_best = self._valid_score.avg > best_val_acc
                if is_best:
                    best_val_acc = self._valid_score.avg
                    self.set_best_valid_score(self._valid_score.avg)
                print('>>>>>>>>>>>>>>>>>>>>>>')
                print('Epoch: {} train loss: {}, train {}: {}, valid loss: {}, valid {}: {}'
                      .format(epoch, self._train_loss.avg, self.print_metric_name,
                              self._train_score.avg, self._valid_loss.avg,
                              self.print_metric_name, self._valid_score.avg))
                print('>>>>>>>>>>>>>>>>>>>>>>')
                self.save_checkpoint({'epoch': epoch + 1,
                                      'state_dict': self.model.state_dict(),
                                      'best_val_acc': best_val_acc,
                                      'optimizer': self._optimizer.state_dict()},
                                     is_best, self.train_checkpoint_save,
                                     self.train_model_save)
                is_best_test = self._test_score.avg > best_test_acc
                if is_best_test:
                    best_test_acc = self._test_score.avg
                    self.set_best_test_score(self._test_score.avg)
                self.save_checkpoint({'epoch': epoch + 1,
                                      'state_dict': self.model.state_dict(),
                                      'best_test_acc': best_test_acc,
                                      'optimizer': self._optimizer.state_dict()},
                                     is_best_test, self.test_checkpoint_save,
                                     self.test_model_save)
                info['best_val_{}'.format(self.print_metric_name)] = self._best_valid_score
                info['best_test_{}'.format(self.print_metric_name)] = self._best_test_score
                info['test_{}'.format(self.print_metric_name)] = self._test_score.avg

            for tag, value in info.items():
                self._logger.scalar_summary(tag, value, epoch + 1)
        print('Training process end.')

    def train_epoch(self, print_freq=100):
        """
        Train function for one epoch. Standard for supervised learning.

        Args:
            print_freq (int): number of steps between printed results.
                The first round always prints.
        """
        losses = self.AverageMeter()
        percent_acc = self.AverageMeter()
        self.model.train()
        time_now = time.time()
        for batch_idx, (data, target) in enumerate(self.train_loader):
            target = target.float()
            if self.target_reshape is not None:
                target = self.target_reshape(target)
            if torch.cuda.is_available():
                data = data.cuda()
                target = target.cuda()
            if self.score_function is None:
                output = self.model(data)
                loss = self._criterion(output, target)
            else:
                output, scores, loss = self.score_function(data, target,
                                                           self._criterion, self.model)
            if self.penalty is not None:
                penalty_val = self.loss_penalty()
                loss += penalty_val
            losses.update(loss.item(), data.size(0))
            if torch.cuda.is_available():
                target = target.to(torch.device("cpu"))
                output = output.to(torch.device("cpu"))
                if self.score_function is not None:
                    scores = scores.to(torch.device("cpu"))
            if self.score_function is None:
                acc = self.metrics(output, target)
            else:
                acc = self.metrics(output, target, scores)
            # Designed specifically for sklearn.metrics.roc_auc_score: extreme
            # values occur when only one class is present in a mini-batch.
            if acc == 0 or acc == 1:
                acc = percent_acc.avg
            percent_acc.update(acc, data.size(0))
            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()
            time_end = time.time() - time_now
            if batch_idx % print_freq == 0 and self.print_result_epoch:
                print('Training Round: {}, Time: {}'.format(batch_idx, np.round(time_end, 2)))
                print('Training Loss: val:{} avg:{} {}: val:{} avg:{}'
                      .format(losses.val, losses.avg, self.print_metric_name,
                              percent_acc.val, percent_acc.avg))
        self.set_train_loss(losses)
        self.set_train_score(percent_acc)

    def validate_epoch(self, print_freq=10000):
        """
        Validation function for every epoch.

        Args:
            print_freq (int): number of steps between printed results.
                The first round always prints.
        """
        self.model.eval()
        losses = self.AverageMeter()
        percent_acc = self.AverageMeter()
        with torch.no_grad():
            time_now = time.time()
            for batch_idx, (data, target) in enumerate(self.val_loader):
                if self.target_reshape is not None:
                    target = self.target_reshape(target)
                target = target.float()
                if torch.cuda.is_available():
                    data = data.cuda()
                    target = target.cuda()
                if self.score_function is None:
                    output = self.model(data)
                    loss = self._criterion(output, target)
                else:
                    output, scores, loss = self.score_function(data, target,
                                                               self._criterion, self.model)
                if self.penalty is not None:
                    penalty_val = self.loss_penalty()
                    loss += penalty_val
                losses.update(loss.item(), data.size(0))
                if torch.cuda.is_available():
                    target = target.to(torch.device("cpu"))
                    output = output.to(torch.device("cpu"))
                    if self.score_function is not None:
                        scores = scores.to(torch.device("cpu"))
                if self.score_function is None:
                    acc = self.metrics(output, target)
                else:
                    acc = self.metrics(output, target, scores)
                # Designed specifically for sklearn.metrics.roc_auc_score: extreme
                # values occur when only one class is present in a mini-batch.
                if acc == 0 or acc == 1:
                    acc = percent_acc.avg
                percent_acc.update(acc, data.size(0))
                time_end = time.time() - time_now
                if batch_idx % print_freq == 0 and self.print_result_epoch:
                    print('Validation Round: {}, Time: {}'.format(batch_idx, np.round(time_end, 2)))
                    print('Validation Loss: val:{} avg:{} {}: val:{} avg:{}'
                          .format(losses.val, losses.avg, self.print_metric_name,
                                  percent_acc.val, percent_acc.avg))
        self.set_valid_score(percent_acc)
        self.set_valid_loss(losses)

    def adjust_learning_rate(self):
        if self._lr_adjust is not None:
            if not isinstance(self._lr_adjust, torch.optim.lr_scheduler._LRScheduler):
                raise TypeError('should inherit from torch.optim.lr_scheduler._LRScheduler.')
            self._lr_adjust.step()
        else:
            print('Learning rate scheduler is not set.')
        """
        lr = self.initial_lr - 0.0000
        # reduce 10 percent every 50 epochs
        for param_group in opt.param_groups:
            param_group['lr'] = lr
        """

    def save_checkpoint(self, state, is_best, checkpoint_save, model_save):
        """
        Save the current state, and copy it to the best-model file when it is the best.

        :param state: dict of states to save.
        :param is_best: whether the designated benchmark is the best in this epoch.
        :param checkpoint_save: file path for the checkpoint.
        :param model_save: file path for the best model.
        """
        # if not os.path.exists(self.checkpoint_save):
        #     os.mkdir(self.checkpoint_save)
        torch.save(state, checkpoint_save)
        if is_best:
            shutil.copyfile(checkpoint_save, model_save)

    def save_model(self):
        return None

    def resume_model(self, resume_file_path):
        if not os.path.exists(resume_file_path):
            raise ValueError('Resume file does not exist')
        else:
            print('=> loading checkpoint {}'.format(resume_file_path))
            checkpoint = torch.load(resume_file_path)
            start_epoch = checkpoint['epoch']
            self.best_val_acc = checkpoint['best_val_acc']
            self.model.load_state_dict(checkpoint['state_dict'])
            self._optimizer.load_state_dict(checkpoint['optimizer'])
            print('=> loaded checkpoint {} of epoch {}'.format(resume_file_path,
                                                               checkpoint['epoch']))

    def evaluate(self, weights=None, print_freq=10000):
        """
        Evaluation on the test loader.

        Args:
            weights (str): optional path to a checkpoint to load before evaluating.
            print_freq (int): number of steps between printed results.
                The first round always prints.
        """
        if weights is not None:
            print('Loading weights from {}'.format(weights))
            self.resume_model(weights)
            print('Weights loaded.')
        self.model.eval()
        percent_acc = self.AverageMeter()
        with torch.no_grad():
            time_now = time.time()
            for batch_idx, (data, target) in enumerate(self.test_loader):
                if torch.cuda.is_available():
                    data = data.cuda()
                    target = target.cuda()
                if self.score_function is None:
                    output = self.model(data)
                else:
                    output, scores, loss = self.score_function(data, target,
                                                               self._criterion, self.model)
                if torch.cuda.is_available():
                    target = target.to(torch.device("cpu"))
                    output = output.to(torch.device("cpu"))
                    if self.score_function is not None:
                        scores = scores.to(torch.device("cpu"))
                if self.score_function is None:
                    acc = self.metrics(output, target)
                else:
                    acc = self.metrics(output, target, scores)
                # Designed specifically for sklearn.metrics.roc_auc_score: extreme
                # values occur when only one class is present in a mini-batch.
                if acc == 0 or acc == 1:
                    acc = percent_acc.avg
                percent_acc.update(acc, data.size(0))
                time_end = time.time() - time_now
        print('Test {}: val:{} avg:{}'.format(self.print_metric_name,
                                              percent_acc.val, percent_acc.avg))
        if not weights:
            print('Test evaluation finished.')
        self.set_test_score(percent_acc)

    def loss_penalty(self):
        if self.penalty['type'] == 'l2':
            l2_penalty = 0
            for param in self.model.parameters():
                l2_penalty = torch.norm(param, 2) + l2_penalty
            l2_penalty = l2_penalty * (0.5 / self.penalty['reg'])
            return l2_penalty
        else:
            raise ValueError('Currently only the l2 penalty is supported.')

    def set_test_score(self, score):
        self._test_score = score

    def set_valid_score(self, score):
        self._valid_score = score

    def set_valid_loss(self, loss):
        self._valid_loss = loss

    def set_train_score(self, score):
        self._train_score = score

    def set_train_loss(self, loss):
        self._train_loss = loss

    def set_best_valid_score(self, score):
        self._best_valid_score = score

    def set_best_test_score(self, score):
        self._best_test_score = score

    def get_best_valid_score(self):
        try:
            return self._best_valid_score
        except Exception:
            print('best valid score is not defined')

    def get_best_test_score(self):
        try:
            return self._best_test_score
        except Exception:
            print('best test score is not defined')

    def multi_threading_val_test(self):
        """
        Multi-threading mode to run validation and test at the same time.
        """
        val_thread = threading.Thread(target=self.validate_epoch)
        val_thread.start()
        test_thread = threading.Thread(target=self.evaluate)
        test_thread.start()
        val_thread.join()
        test_thread.join()

    class AverageMeter(object):
        """Computes and stores the average and current value."""

        def __init__(self):
            self.reset()

        def reset(self):
            self.val = 0
            self.avg = 0
            self.sum = 0
            self.count = 0

        def update(self, val, n=1):
            self.val = val
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count
class TrainModel(object):

    def __init__(self):
        self.logger = Logger('logs/')
        # model
        self.model = None
        self.optimizer = None
        self.lr_scheduler = None
        # data
        self.train_loader = None
        self.val_loader = None
        self.train_data = None
        self.val_data = None

    def val_step(self):
        """Validation step."""
        cum_loss = 0
        predicts = []
        truths = []
        self.model.eval()
        for inputs, masks, target in tqdm(self.val_loader, total=len(self.val_loader),
                                          ascii=True, desc='validation'):
            inputs, masks, target = inputs.to(device), masks.to(device), target.to(device)
            with torch.set_grad_enabled(False):
                out = self.model(inputs)
                loss1 = nn.BCEWithLogitsLoss()(out, masks)
                loss2 = lovasz_softmax(F.softmax(out, dim=1), target)  # tune
                loss = loss1 + loss2
            predicts.append(torch.sigmoid(out).detach().cpu().numpy())
            truths.append(masks.detach().cpu().numpy())
            cum_loss += loss.item() * inputs.size(0)
            gc.collect()
        start = time.time()
        predicts = np.concatenate(predicts).squeeze()
        truths = np.concatenate(truths).squeeze()
        mean_dice = dice_channel_torch(predicts, truths, 0.5)
        val_loss = cum_loss / len(self.val_data)
        print(f"Val calculated: {(time.time() - start):.3f}s")
        gc.collect()
        return val_loss, mean_dice

    def train_step(self):
        """Training step."""
        cum_loss = 0
        self.model.train()
        for inputs, masks, target in tqdm(self.train_loader, total=len(self.train_loader),
                                          ascii=True, desc='train'):
            inputs, masks, target = inputs.to(device), masks.to(device), target.to(device)
            self.optimizer.zero_grad()
            with torch.set_grad_enabled(True):
                out = self.model(inputs)
                loss1 = nn.BCEWithLogitsLoss()(out, masks)
                loss2 = lovasz_softmax(F.softmax(out, dim=1), target)  # tune
                loss = loss1 + loss2
                loss.backward()
                self.optimizer.step()
            gc.collect()
            cum_loss += loss.item() * inputs.size(0)
        epoch_loss = cum_loss / len(self.train_data)
        gc.collect()
        return epoch_loss

    def logger_step(self, cur_epoch, losses_train, losses_val, dice):
        """Log information."""
        print(f"[Epoch {cur_epoch}] training loss: {losses_train[-1]:.6f} | "
              f"val_loss: {losses_val[-1]:.6f} | val_dice: {dice:.6f}")
        # print(f"Learning rate: {self.lr_scheduler.get_lr()[0]:.6f}")

        # 1. Log scalar values (scalar summary).
        info = {'loss': losses_train[-1], 'val_loss': losses_val[-1], 'dice': dice}
        for tag, value in info.items():
            self.logger.scalar_summary(tag, value, cur_epoch + 1)

        # 2. Log values and gradients of the parameters (histogram summary).
        for tag, value in self.model.named_parameters():
            tag = tag.replace('.', '/')
            self.logger.histo_summary(tag, value.data.cpu().numpy(), cur_epoch + 1)
            self.logger.histo_summary(tag + '/grad', value.grad.data.cpu().numpy(),
                                      cur_epoch + 1)
        return True

    def main(self):
        """Main training loop."""
        # Get the model.
        self.model = smp.Unet(args.model, classes=4, encoder_weights='imagenet')
        self.model = torch.nn.Sequential(*(list(self.model.children())[:-1]))
        self.model.to(device)
        self.model.load_state_dict(torch.load('output/weights/resnet34_f0_s3.pth'))

        scheduler_step = args.epoch // args.snapshot

        num_train = len(os.listdir('input/severstal-steel-defect-detection/train_images'))
        # num_train = 1000
        indices = list(range(num_train))
        if args.num_fold > 1:
            kf = KFold(n_splits=args.num_fold, random_state=42, shuffle=True)
            train_idx = []
            valid_idx = []
            for t, v in kf.split(indices):
                train_idx.append(t)
                valid_idx.append(v)
        elif args.num_fold == 1:
            train_idx, valid_idx, _, _ = train_test_split(indices, indices,
                                                          test_size=0.2, random_state=42)
            train_idx, valid_idx = [train_idx], [valid_idx]
        else:
            raise ValueError('Invalid value for args.num_fold')

        for fold in range(args.num_fold):
            print(f'************************'
                  f'**** [FOLD: {fold}] ****'
                  f'************************')
            self.train_data = getDatabase(mode='train', image_idx=train_idx[fold])
            self.train_loader = DataLoader(self.train_data,
                                           sampler=RandomSampler(self.train_data),
                                           batch_size=args.batch_size,
                                           num_workers=6,
                                           pin_memory=True)
            self.val_data = getDatabase(mode='val', image_idx=valid_idx[fold])
            self.val_loader = DataLoader(self.val_data,
                                         shuffle=False,
                                         batch_size=args.batch_size,
                                         num_workers=6,
                                         pin_memory=True)

            num_snapshot = 0
            best_acc = 0
            best_param = self.model.state_dict()

            # Set up the optimizer.
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=args.max_lr,
                                             momentum=args.momentum,
                                             weight_decay=args.weight_decay)
            # self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            #     self.optimizer, scheduler_step, args.min_lr)
            self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, patience=6, verbose=True)

            # Service variables.
            losses_train = []  # save training losses
            losses_val = []    # save validation losses

            for epoch in range(args.epoch):
                train_loss = self.train_step()
                # train_loss = 1
                val_loss, accuracy = self.val_step()
                # self.lr_scheduler.step()          # for CosineAnnealingLR
                self.lr_scheduler.step(val_loss)    # for ReduceLROnPlateau
                losses_train.append(train_loss)
                losses_val.append(val_loss)
                self.logger_step(epoch, losses_train, losses_val, accuracy)

                # Scheduler checkpoint.
                if accuracy >= best_acc:
                    best_acc = accuracy
                    best_param = self.model.state_dict()
                    torch.save(best_param,
                               args.save_weight + args.weight_name + '_lrPlateau' + '.pth')
                if (epoch + 1) % scheduler_step == 0:
                    torch.save(best_param, args.save_weight + args.weight_name +
                               '_f' + str(fold) + '_s' + str(num_snapshot) + '.pth')
                    self.optimizer = torch.optim.SGD(self.model.parameters(),
                                                     lr=args.max_lr,
                                                     momentum=args.momentum,
                                                     weight_decay=args.weight_decay)
                    self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                        self.optimizer, scheduler_step, args.min_lr)
                    num_snapshot += 1
                    best_acc = 0
        return True
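# ---------------------------------------------------------------------------
# Standalone sketch (my illustration, not code from this repo) of the snapshot
# pattern used in TrainModel.main(): every `scheduler_step` epochs the best
# weights so far are written out and the SGD optimizer plus cosine schedule
# are rebuilt, so each snapshot begins a fresh warm restart from max_lr.
# `train_one_epoch` and all hyperparameters here are illustrative.
def snapshot_schedule_demo(model, train_one_epoch, epochs=12, scheduler_step=4,
                           max_lr=0.1, min_lr=1e-3):
    snapshots = []
    optimizer = torch.optim.SGD(model.parameters(), lr=max_lr, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, scheduler_step, min_lr)
    for epoch in range(epochs):
        train_one_epoch(model, optimizer)  # caller-supplied training step
        scheduler.step()
        if (epoch + 1) % scheduler_step == 0:
            # Save a snapshot, then restart the cosine cycle from max_lr.
            snapshots.append({k: v.clone() for k, v in model.state_dict().items()})
            optimizer = torch.optim.SGD(model.parameters(), lr=max_lr, momentum=0.9)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                                   scheduler_step, min_lr)
    return snapshots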