Example #1
File: raa_dqn.py  Project: yyf17/raa-drl
def dqn_learning(env,
                 q_func,
                 optimizer_spec,
                 exploration=LinearSchedule(1000000, 0.1),
                 max_steps=20e6,
                 replay_buffer_size=1000000,
                 batch_size=32,
                 sample_size=128,
                 gamma=0.99,
                 beta=0.05,
                 reg_scale=0.1,
                 use_restart=True,
                 learning_starts=50000,
                 learning_freq=4,
                 frame_history_len=4,
                 target_update_freq=2000,
                 save_path=None):
    """Run Deep Q-learning algorithm with regularized anderson acceleration.
    You can specify your own convnet using q_func.
    All schedules are w.r.t. total number of steps taken in the environment.
    Parameters
    ----------
    env: gym.Env
        gym environment to train on.
    q_func: function
        Model to use for computing the q function.
    optimizer_spec: OptimizerSpec
        Specifying the constructor and kwargs, as well as learning rate schedule
        for the optimizer
    exploration: rl_algs.deepq.utils.schedules.Schedule
        schedule for probability of choosing a random action.
    max_steps: float
        Maximum number of environment steps to run.
    replay_buffer_size: int
        How many memories to store in the replay buffer.
    batch_size: int
        How many of the sampled transitions are used for each Q-network update.
    sample_size: int
        How many transitions to sample each time experience is replayed;
        the full sample is used for the Anderson acceleration step.
    gamma: float
        Discount factor.
    beta: float
        Mixing coefficient between the stored target Q values (weight beta)
        and their Bellman updates (weight 1 - beta) when forming the
        Anderson-accelerated target.
    reg_scale: float
        Regularization scale for the RAA solver.
    use_restart: bool
        Whether the RAA solver is allowed to restart the acceleration.
    learning_starts: int
        After how many environment steps to start replaying experiences
    learning_freq: int
        How many steps of environment to take between every experience replay
    frame_history_len: int
        How many past frames to include as input to the model.
    target_update_freq: int
        How many experience replay rounds (not steps!) to perform between
        each update to the target Q network
    save_path: str or None
        Directory where model weights and training scalars are saved.
    """
    assert type(env.observation_space) == gym.spaces.Box
    assert type(env.action_space) == gym.spaces.Discrete

    # Set the logger
    logger = Logger(save_path)

    ###############
    # BUILD MODEL #
    ###############

    if len(env.observation_space.shape) == 1:
        # This means we are running on low-dimensional observations (e.g. RAM)
        input_shape = env.observation_space.shape
        in_channels = input_shape[0]
    else:
        img_h, img_w, img_c = env.observation_space.shape
        input_shape = (img_h, img_w, frame_history_len * img_c)
        in_channels = input_shape[2]
    num_actions = env.action_space.n

    # define Q target and Q
    Q = q_func(in_channels, num_actions).to(device)
    Q_targets = []
    MAX_NUM = 5
    for i in range(MAX_NUM):
        Q_targets.append(q_func(in_channels, num_actions).to(device))

    # initialize anderson
    anderson = RAA(MAX_NUM, use_restart, reg_scale)

    # initialize optimizer
    optimizer = optimizer_spec.constructor(Q.parameters(),
                                           **optimizer_spec.kwargs)

    # create replay buffer
    replay_buffer = ReplayBuffer(replay_buffer_size, frame_history_len)

    ######

    ###############
    # RUN ENV     #
    ###############
    num_param_updates = 0
    mean_episode_reward = -float('nan')
    best_mean_episode_reward = -float('inf')
    last_obs = env.reset()
    LOG_EVERY_N_STEPS = 10000
    SAVE_MODEL_EVERY_N_STEPS = 100000
    saved_scalars = []
    stop = False
    restart = True
    cur_num = 1
    clipped_error = torch.FloatTensor([0]).to(device)

    for t in itertools.count():
        # 1. Step the env and store the transition
        # store last frame, returned idx used later
        last_stored_frame_idx = replay_buffer.store_frame(last_obs)

        # get observations to input to Q network (need to append prev frames)
        observations = replay_buffer.encode_recent_observation()  # torch

        # before learning starts, choose actions randomly
        if t < learning_starts:
            action = np.random.randint(num_actions)
        else:
            # epsilon greedy exploration
            sample = random.random()
            threshold = exploration.value(t)
            if sample > threshold:
                obs = observations.unsqueeze(0) / 255.0
                with torch.no_grad():
                    q_value_all_actions = Q(obs)
                action = (q_value_all_actions.data.max(1)[1])[0]
            else:
                action = torch.IntTensor([[np.random.randint(num_actions)]
                                          ])[0][0]

        obs, reward, done, info = env.step(action)

        # clipping the reward, noted in nature paper
        reward = np.clip(reward, -1.0, 1.0)

        # store effect of action
        replay_buffer.store_effect(last_stored_frame_idx, action, reward, done)

        # reset env if reached episode boundary
        if done:
            obs = env.reset()

        # update last_obs
        last_obs = obs

        # 2. Perform experience replay and train the network.
        # if the replay buffer contains enough samples...
        if (t > learning_starts and t % learning_freq == 0
                and replay_buffer.can_sample(sample_size)):

            # sample transition batch from replay memory
            # done_mask = 1 if next state is end of episode
            obs_t, act_t, rew_t, obs_tp1, done_mask = replay_buffer.sample(
                sample_size)
            obs_t = obs_t / 255.0
            act_t = torch.LongTensor(act_t).to(device)
            rew_t = torch.FloatTensor(rew_t).to(device)
            obs_tp1 = obs_tp1 / 255.0
            # convert like act_t/rew_t so the mask can enter tensor arithmetic below
            done_mask = torch.FloatTensor(done_mask).to(device)

            # input batches to networks
            # get the Q values for current observations (Q(s,a, theta_i))
            q_values = Q(obs_t[:batch_size, :])
            q_s_a = q_values.gather(1, act_t[:batch_size].unsqueeze(1))
            q_s_a = q_s_a.squeeze()

            if restart:
                cur_num = 1
                restart = False

                # get the Q values for best actions in obs_tp1
                # based off frozen Q network
                # max(Q(s', a', theta_i_frozen)) wrt a'
                q_tp1_values = Q_targets[-1](obs_tp1[:batch_size, :]).detach()
                q_s_a_prime, _ = q_tp1_values.max(1)

                # if current state is end of episode, then there is no next Q value
                q_rhs = rew_t[:batch_size] + gamma * (
                    1 - done_mask[:batch_size]) * q_s_a_prime
            else:
                cur_num += 1
                num = min(MAX_NUM, cur_num)

                cat_obs = torch.cat((obs_t, obs_tp1), 0)

                qs_target_t_aa, qs_target_tp1_aa = [], []
                for i in range(num, 0, -1):
                    q_target = Q_targets[-i](cat_obs).detach()

                    q_aa = q_target[:sample_size, :].gather(
                        1, act_t.unsqueeze(1))
                    qs_target_t_aa.append(q_aa.t())

                    q_next_aa, _ = q_target[sample_size:, :].max(1)
                    qs_target_tp1_aa.append(q_next_aa.unsqueeze(0))

                qs_target_t_values = torch.cat(qs_target_t_aa, 0)
                qs_target_tp1_values = torch.cat(qs_target_tp1_aa, 0)

                F_qs_target_t = torch.cat([(rew_t + gamma *
                                            (1 - done_mask) * q).unsqueeze(0)
                                           for q in qs_target_tp1_values], 0)

                alpha, restart = anderson.calculate(qs_target_t_values,
                                                    F_qs_target_t)
                # regularized Anderson acceleration target:
                # mix the stored target Q values (weight beta) with their Bellman
                # updates (weight 1 - beta), then combine the history with the
                # Anderson coefficients alpha
                hybrid_qs_target_tp1 = beta * qs_target_t_values[:, :batch_size] + \
                                       (1 - beta) * F_qs_target_t[:, :batch_size]
                q_rhs = (hybrid_qs_target_tp1.t().mm(alpha)).detach()
                q_rhs = q_rhs.squeeze(1)

            # Compute Bellman error
            # r + gamma * Q(s',a', theta_i_frozen) - Q(s, a, theta_i)
            error = q_rhs - q_s_a

            # clip the error and flip
            clipped_error = -1.0 * error.clamp(-1, 1)

            # backwards pass
            optimizer.zero_grad()
            q_s_a.backward(clipped_error.data)

            # update
            optimizer.step()
            num_param_updates += 1

            # update target Q network weights with current Q network weights
            # (refresh the oldest entry in the target history and rotate it to the back)
            if num_param_updates % target_update_freq == 0:
                Q_targets[0].load_state_dict(Q.state_dict())
                Q_targets.append(Q_targets.pop(0))

        # 3. Log progress
        if t % SAVE_MODEL_EVERY_N_STEPS == 0:
            if save_path is not None:
                torch.save(Q.state_dict(), '%s/net.pth' % save_path)

        if t % LOG_EVERY_N_STEPS == 0:
            underlying_env = get_wrapper_by_name(env, "Monitor")
            internal_steps = underlying_env.get_total_steps()
            stop = (internal_steps >= max_steps)
            episode_rewards = underlying_env.get_episode_rewards()
            num_episode = len(episode_rewards)

            if num_episode > 0:
                mean_episode_reward = np.mean(episode_rewards[-100:])
                best_mean_episode_reward = max(best_mean_episode_reward,
                                               mean_episode_reward)

                saved_scalars.append([
                    t, internal_steps, num_episode, mean_episode_reward,
                    clipped_error.mean().data.cpu().numpy()
                ])
                if save_path is not None:
                    np.save('%s/scalars.npy' % save_path, saved_scalars)

            print("---------------------------------")
            print("Wrapped - Atari (steps) %d-%d" % (t, internal_steps))
            print("episodes %d" % num_episode)
            print("mean episode reward %f" % mean_episode_reward)
            print("best mean episode reward %f" % best_mean_episode_reward)
            print("exploration %f" % exploration.value(t))
            sys.stdout.flush()

            # ============ TensorBoard logging ============#
            info = {
                'num_episodes': len(episode_rewards),
                'exploration': exploration.value(t),
                'mean_episode_reward_last_100': mean_episode_reward
            }
            for tag, value in info.items():
                logger.scalar_summary(tag, value, t + 1)

        # 4. Check the stop criteria
        if stop:
            break
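
The expected call pattern is easiest to see from a driver script. The sketch below is a hypothetical usage example, not code from yyf17/raa-drl: OptimizerSpec, TinyQNet, the environment id, and all hyperparameter values are assumptions, and the real project presumably wraps the Atari environment with its own preprocessing in addition to the gym Monitor wrapper that get_wrapper_by_name(env, "Monitor") relies on.

# Hypothetical usage sketch -- names below other than dqn_learning and
# LinearSchedule are assumptions, not part of the original project.
from collections import namedtuple

import gym
import torch.nn as nn
import torch.optim as optim

OptimizerSpec = namedtuple('OptimizerSpec', ['constructor', 'kwargs'])

class TinyQNet(nn.Module):
    # Minimal stand-in for q_func(in_channels, num_actions); the project
    # presumably uses a deeper Nature-DQN style convnet.
    def __init__(self, in_channels, num_actions):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=8, stride=4), nn.ReLU(),
            nn.Flatten(),
            nn.LazyLinear(num_actions))

    def forward(self, x):
        return self.features(x)

env = gym.wrappers.Monitor(gym.make('PongNoFrameskip-v4'), 'videos/', force=True)

dqn_learning(env,
             q_func=TinyQNet,
             optimizer_spec=OptimizerSpec(optim.Adam, {'lr': 1e-4}),
             exploration=LinearSchedule(1000000, 0.1),
             max_steps=20e6,
             save_path='runs/pong_raa')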
Example #2
class NN(object):
    """
    This is a prototype for NN wrapper in Pytorch.
    Please follow this coding style carefully.
    Args:
        model:
            Pytorch Model.
        train_loader (torch.dataset.DataLoader) :
            pytorch DataLoader for training dataset
        val_loader (torch.dataset.DataLoader) :
            pytorch DataLodaer for validation dataset
        epochs:

        opt (torch.optim) :
            optimizer
        criterion:
            Loss function.
        initial_lr (float):
            Initial learning rate. TODO implement using lr_find()
        checkpoint_save (str):
            Directory to save check point.
        model_save (str):
            Directory to save model.
        dataset:

        model:
            Pytorch Model.
        param_diagonstic (bool):
            check parameters, will be print. TODO record parameters.
        if_checkpoint_save (bool):
            save checkpoint if True
        print_result_epoch (bool):
            true if results some steps at every epochs are print.
        metrics :
            Evaluation metrics.
    """

    def __init__(self, model=None, train_loader=None, val_loader=None,
                 test_loader=None,
                 if_checkpoint_save=True,
                 penalty=None,
                 print_result_epoch=False,
                 print_metric_name=None,
                 metrics=None,
                 score_function=None,
                 create_save_file_name=None,
                 target_reshape=None, **kwargs):
        self.test_loader = test_loader
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        self.train_current_batch_data = {}
        self.valid_current_batch_data = {}
        self.if_checkpoint_save = if_checkpoint_save
        self.print_result_epoch = print_result_epoch
        self.penalty = penalty
        self.target_reshape = target_reshape
        self.metrics = metrics
        self.score_function = score_function
        self.create_save_file_name = create_save_file_name
        self.print_metric_name = print_metric_name


        self.epochs = self.model.get_epochs()
        self._optimizer = self.model.get_optimizer()
        self._criterion = self.model.get_criterion()
        self._lr_adjust = self.model.get_lr_scheduler()
        self._tensorboard_path = self.model.get_tensorboard_path()
        self._save_path = self.model.get_logger_path()
        self._logger = Logger(self._tensorboard_path)

        if not os.path.exists(os.path.join(self._save_path, 'train_save')):
            os.makedirs(os.path.join(self._save_path, 'train_save'))
        if not os.path.exists(os.path.join(self._save_path, 'test_save')):
            os.makedirs(os.path.join(self._save_path, 'test_save'))
        print(self._save_path, self.create_save_file_name())

        self.train_checkpoint_save = os.path.join(self._save_path, 'train_save',
                                                  self.create_save_file_name() + '_ckpt.path.tar')
        self.train_model_save = os.path.join(self._save_path, 'train_save',
                                             self.create_save_file_name() + '_best.path.tar')

        self.test_checkpoint_save = os.path.join(self._save_path, 'test_save',
                                                 self.create_save_file_name() + '_ckpt.path.tar')
        self.test_model_save = os.path.join(self._save_path, 'test_save',
                                             self.create_save_file_name() + '_best.path.tar')

        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise TypeError('optimizer should be a torch.optim.Optimizer, got {} instead'.format(type(self._optimizer)))

        global best_val_acc, best_test_acc

    def train(self):
        print('Start training process.')
        self.adjust_learning_rate()


        best_val_acc = 0
        best_test_acc = 0

        for epoch in range(self.epochs):
            start_time = time.time()
            self.train_epoch()
            if not torch.cuda.is_available() and self.test_loader is not None:
                self.multi_threading_val_test()
            elif torch.cuda.is_available() and self.test_loader is not None:
                self.evaluate()
                self.validate_epoch()
            else:
                self.validate_epoch()
            info = {'train_loss': self._train_loss.avg, 'train_{}'.format(self.print_metric_name): self._train_score.avg,
                    'val_loss': self._valid_loss.avg, 'val_{}'.format(self.print_metric_name): self._valid_score.avg}

            end_time = time.time()
            if self.if_checkpoint_save and self.test_loader is None:
                is_best = self._valid_score.avg > best_val_acc
                if is_best:
                    best_val_acc = self._valid_score.avg
                    self.set_best_valid_score(self._valid_score.avg)
                print('>>>>>>>>>>>>>>>>>>>>>>')
                print('epoch {} takes {:.2f}s to train'.format(epoch, end_time - start_time))
                print(
                    'Epoch: {} train loss: {}, train {}: {}, valid loss: {}, valid {}: {}'.format(epoch, self._train_loss.avg,
                                                                                                  self.print_metric_name,
                                                                                                  self._train_score.avg,
                                                                                                  self._valid_loss.avg,
                                                                                                  self.print_metric_name,
                                                                                                  self._valid_score.avg))
                print('>>>>>>>>>>>>>>>>>>>>>>')
                self.save_checkpoint({'epoch': epoch + 1,
                                      'state_dict': self.model.state_dict(),
                                      'best_val_acc': best_val_acc,
                                      'optimizer': self._optimizer.state_dict(), }, is_best,
                                     self.train_checkpoint_save, self.train_model_save)
                info['best_val_{}'.format(self.print_metric_name)] = self._best_valid_score

            elif self.if_checkpoint_save and self.test_loader is not None:
                is_best = self._valid_score.avg > best_val_acc
                if is_best:
                    best_val_acc = self._valid_score.avg
                    self.set_best_valid_score(self._valid_score.avg)
                print('>>>>>>>>>>>>>>>>>>>>>>')
                print(
                    'Epoch: {} train loss: {}, train {}: {}, valid loss: {}, valid {}: {}'.format(epoch,
                                                                                                  self._train_loss.avg,
                                                                                                  self.print_metric_name,
                                                                                                  self._train_score.avg,
                                                                                                  self._valid_loss.avg,
                                                                                                  self.print_metric_name,
                                                                                                  self._valid_score.avg))
                print('>>>>>>>>>>>>>>>>>>>>>>')
                self.save_checkpoint({'epoch': epoch + 1,
                                      'state_dict': self.model.state_dict(),
                                      'best_val_acc': best_val_acc,
                                      'optimizer': self._optimizer.state_dict(), }, is_best,
                                     self.train_checkpoint_save, self.train_model_save)

                is_best_test = self._test_score.avg > best_test_acc
                if is_best_test:
                    best_test_acc = self._test_score.avg
                    self.set_best_test_score(self._test_score.avg)
                self.save_checkpoint({'epoch': epoch + 1,
                                      'state_dict': self.model.state_dict(),
                                      'best_test_acc': best_test_acc,
                                      'optimizer': self._optimizer.state_dict(), }, is_best_test,
                                     self.test_checkpoint_save, self.test_model_save)

                info['best_val_{}'.format(self.print_metric_name)] = self._best_valid_score
                info['best_test_{}'.format(self.print_metric_name)] = self._best_test_score
                info['test_{}'.format(self.print_metric_name)] = self._test_score.avg

            for tag, value in info.items():
                self._logger.scalar_summary(tag, value, epoch+1)

        print('Training process end.')

    def train_epoch(self, print_freq=100):
        """
        Train function for every epoch. Standard for supervised learning.
        Args:
            print_freq(int): number of step to print results. The first round always print.
        """
        losses = self.AverageMeter()
        percent_acc = self.AverageMeter()
        self.model.train()
        time_now = time.time()

        for batch_idx, (data, target) in enumerate(self.train_loader):
            target = target.float()

            if self.target_reshape is not None:
                target = self.target_reshape(target)
            if torch.cuda.is_available():
                data = data.cuda()
                target = target.cuda()

            if self.score_function is None:
                output = self.model(data)
                loss = self._criterion(output, target)
            else:
                output, scores, loss = self.score_function(data, target, self._criterion, self.model)

            if self.penalty is not None:
                penalty_val = self.loss_penalty()
                loss += penalty_val

            losses.update(loss.item(), data.size(0))

            if torch.cuda.is_available():
                target = target.to(torch.device("cpu"))
                output = output.to(torch.device("cpu"))
                if self.score_function is not None:
                    scores = scores.to(torch.device("cpu"))

            if self.score_function is None:
                acc = self.metrics(output, target)
            else:
                acc = self.metrics(output, target, scores)

            # this is designed specifically for sklearn.metrics.roc_auc_score:
            # extreme values occur when only one class is present in the mini-batch
            if acc == 0 or acc == 1:
                acc = percent_acc.avg

            percent_acc.update(acc, data.size(0))

            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()

            time_end = time.time() - time_now
            if batch_idx % print_freq == 0 and self.print_result_epoch:
                print('Training Round: {}, Time: {}'.format(batch_idx,
                                            np.round(time_end, 2)))
                print('Training Loss: val:{} avg:{} {}: val:{} avg:{}'.format(losses.val,
                                                              losses.avg,
                                                              self.print_metric_name,
                                                              percent_acc.val, percent_acc.avg))
        self.set_train_loss(losses)
        self.set_train_score(percent_acc)

    def validate_epoch(self, print_freq=10000):
        """
        Validation function for every epoch.
        Args:
            print_freq(int): number of step to print results. The first round always print.
        """
        self.model.eval()
        losses = self.AverageMeter()
        percent_acc = self.AverageMeter()

        with torch.no_grad():
            time_now = time.time()
            for batch_idx, (data, target) in enumerate(self.val_loader):
                if self.target_reshape is not None:
                    target = self.target_reshape(target)
                target = target.float()
                if torch.cuda.is_available():
                    data = data.cuda()
                    target = target.cuda()

                if self.score_function is None:
                    output = self.model(data)
                    loss = self._criterion(output, target)
                else:
                    output, scores, loss = self.score_function(data, target, self._criterion, self.model)

                if self.penalty is not None:
                    penalty_val = self.loss_penalty()
                    loss += penalty_val

                losses.update(loss.item(), data.size(0))

                if torch.cuda.is_available():
                    target = target.to(torch.device("cpu"))
                    output = output.to(torch.device("cpu"))
                    if self.score_function is not None:
                        scores = scores.to(torch.device("cpu"))

                if self.score_function is None:
                    acc = self.metrics(output, target)
                else:
                    acc = self.metrics(output, target, scores)

                # this is designed specifically for sklearn.metrics.roc_auc_score:
                # extreme values occur when only one class is present in the mini-batch
                if acc == 0 or acc == 1:
                    acc = percent_acc.avg


                percent_acc.update(acc, data.size(0))
                time_end = time.time() - time_now
                if batch_idx % print_freq == 0 and self.print_result_epoch:
                    print('Validation Round: {}, Time: {}'.format(batch_idx, np.round(time_end, 2)))
                    print('Validation Loss: val:{} avg:{} {}: val:{} avg:{}'.format(losses.val, losses.avg, self.print_metric_name,
                                                                                 percent_acc.val, percent_acc.avg))
        self.set_valid_score(percent_acc)
        self.set_valid_loss(losses)

    def adjust_learning_rate(self):
        if self._lr_adjust is not None:
            if not isinstance(self._lr_adjust, torch.optim.lr_scheduler._LRScheduler):
                raise TypeError('lr scheduler should inherit from torch.optim.lr_scheduler._LRScheduler.')
            self._lr_adjust.step()
        else:
            print('Learning rate scheduler is not set.')
        """
        lr = self.initial_lr - 0.0000  # reduce 10 percent every 50 epoch
        for param_group in opt.param_groups:
            param_group['lr'] = lr
        """

    def save_checkpoint(self, state, is_best_test, checkpoint_save, model_save):
        """
        save the best states.
        :param state:
        :param is_best: if the designated benchmark is the best in this epoch.
        :param ckpt_filename: the file path to save checkpoint, will be create if not exist.
        """
        #if not os.path.exists(self.checkpoint_save):
        #    os.mkdir(self.checkpoint_save)
        torch.save(state, checkpoint_save)
        if is_best_test:
            shutil.copyfile(checkpoint_save, model_save)

    def save_model(self):
        return None

    def resume_model(self, resume_file_path):
        if not os.path.exists(resume_file_path):
            raise ValueError('Resume file does not exist')
        else:
            print('=> loading checkpoint {}'.format(resume_file_path))
            checkpoint = torch.load(resume_file_path)
            start_epoch = checkpoint['epoch']
            self.best_val_acc = checkpoint['best_val_acc']
            self.model.load_state_dict(checkpoint['state_dict'])
            self._optimizer.load_state_dict(checkpoint['optimizer'])
            print('=> loaded checkpoint {} of epoch {}'.format(resume_file_path,
                checkpoint['epoch']))

    def evaluate(self, weights=None, print_freq=10000):
        """
        Validation function for every epoch.
        Args:
            data_loader (torch.utils.dataset.Dataloader): Dataloader for testing.
            print_freq(int): number of step to print results. The first round always print.
        """
        if weights is not None:
            print('Loading weights from {}'.format(weights))
            self.resume_model(weights)
            print('Weights loaded.')
        self.model.eval()
        percent_acc = self.AverageMeter()

        with torch.no_grad():
            time_now = time.time()
            for batch_idx, (data, target) in enumerate(self.test_loader):
                if torch.cuda.is_available():
                    data = data.cuda()
                    target = target.cuda()

                if self.score_function is None:
                    output = self.model(data)
                else:
                    output, scores, loss = self.score_function(data, target, self._criterion, self.model)

                if torch.cuda.is_available():
                    target = target.to(torch.device("cpu"))
                    output = output.to(torch.device("cpu"))
                    if self.score_function is not None:
                        scores = scores.to(torch.device("cpu"))

                if self.score_function is None:
                    acc = self.metrics(output, target)
                else:
                    acc = self.metrics(output, target, scores)

                # this is designed specifically for sklearn.metrics.roc_auc_score:
                # extreme values occur when only one class is present in the mini-batch
                if acc == 0 or acc == 1:
                    acc = percent_acc.avg

                percent_acc.update(acc, data.size(0))
                time_end = time.time() - time_now
        print('Test {}: val:{} avg:{}'.format(self.print_metric_name, percent_acc.val, percent_acc.avg))
        if not weights:
            print('Test evaluation finished.')
        self.set_test_score(percent_acc)

    def loss_penalty(self):
        if self.penalty['type'] == 'l2':
            l2_penalty = 0

            for param in self.model.parameters():
                l2_penalty = torch.norm(param, 2) + l2_penalty
            l2_penalty = l2_penalty * (0.5 / self.penalty['reg'])
            return l2_penalty
        else:
            raise ValueError('Currently only the l2 penalty is supported')

    def set_test_score(self, score):
        self._test_score = score

    def set_valid_score(self, score):
        self._valid_score = score

    def set_valid_loss(self, loss):
        self._valid_loss = loss

    def set_train_score(self, score):
        self._train_score = score

    def set_train_loss(self, loss):
        self._train_loss = loss

    def set_best_valid_score(self, score):
        self._best_valid_score = score

    def set_best_test_score(self, score):
        self._best_test_score = score

    def get_best_valid_score(self):
        try:
            return self._best_valid_score
        except Exception:
            print('best valid score is not defined')

    def get_best_test_score(self):
        try:
            return self._best_test_score
        except Exception:
            print('best test score is not defined')

    def multi_threading_val_test(self):
        """
        Multi threading mode to validation and test at the same time.
        """
        val_thread = threading.Thread(target=self.validate_epoch)
        val_thread.start()
        test_thread = threading.Thread(target=self.evaluate)
        test_thread.start()

        val_thread.join()
        test_thread.join()

    class AverageMeter(object):
        """Computes and stores the average and current value"""

        def __init__(self):
            self.reset()

        def reset(self):
            self.val = 0
            self.avg = 0
            self.sum = 0
            self.count = 0

        def update(self, val, n=1):
            self.val = val
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count
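
Note that NN pulls its training configuration from the model itself: __init__ calls get_epochs(), get_optimizer(), get_criterion(), get_lr_scheduler(), get_tensorboard_path() and get_logger_path() on the wrapped model. The sketch below shows one way a model can satisfy that contract; WrappedMLP, the toy data, and the accuracy metric are illustrative assumptions, not part of the original project.

# Hypothetical usage sketch of the model contract assumed by NN.__init__.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

class WrappedMLP(nn.Module):
    # Minimal model exposing the getters that NN.__init__ calls.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)
        self._optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, x):
        return self.fc(x)

    def get_epochs(self): return 5
    def get_optimizer(self): return self._optimizer
    def get_criterion(self): return nn.BCEWithLogitsLoss()
    def get_lr_scheduler(self): return None
    def get_tensorboard_path(self): return 'runs/tb'
    def get_logger_path(self): return 'runs'

# Toy binary-classification data so the sketch is self-contained.
X = torch.randn(64, 10)
y = (torch.rand(64, 1) > 0.5).float()
loader = DataLoader(TensorDataset(X, y), batch_size=16)

trainer = NN(model=WrappedMLP(),
             train_loader=loader,
             val_loader=loader,
             metrics=lambda out, target: ((out > 0).float() == target).float().mean().item(),
             print_metric_name='acc',
             create_save_file_name=lambda: 'mlp_demo')
trainer.train()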
Example #3
class TrainModel(object):
    def __init__(self):
        self.logger = Logger('logs/')

        # model
        self.model = None
        self.optimizer = None
        self.lr_scheduler = None

        # data
        self.train_loader = None
        self.val_loader = None
        self.train_data = None
        self.val_data = None

    def val_step(self):
        """ Validation step """
        cum_loss = 0
        predicts = []
        truths = []

        self.model.eval()
        for inputs, masks, target in tqdm(self.val_loader, total=len(self.val_loader), ascii=True, desc='validation'):
            inputs, masks, target = inputs.to(device), masks.to(device), target.to(device)
            with torch.set_grad_enabled(False):
                out = self.model(inputs)
                loss1 = nn.BCEWithLogitsLoss()(out, masks)
                loss2 = lovasz_softmax(F.softmax(out, dim=1), target)  # tune
                loss = loss1 + loss2

            predicts.append(F.sigmoid(out).detach().cpu().numpy())
            truths.append(masks.detach().cpu().numpy())
            cum_loss += loss.item() * inputs.size(0)
            gc.collect()

        start = time.time()
        predicts = np.concatenate(predicts).squeeze()
        truths = np.concatenate(truths).squeeze()
        mean_dice = dice_channel_torch(predicts, truths, 0.5)
        val_loss = cum_loss / self.val_data.__len__()
        print(f"Val calculated: {(time.time() - start):.3f}s")
        gc.collect()
        return val_loss, mean_dice

    def train_step(self):
        """ Training step """
        cum_loss = 0
        self.model.train()
        for inputs, masks, target in tqdm(self.train_loader, total=len(self.train_loader), ascii=True, desc='train'):
            inputs, masks, target = inputs.to(device), masks.to(device), target.to(device)
            self.optimizer.zero_grad()

            with torch.set_grad_enabled(True):
                out = self.model(inputs)
                loss1 = nn.BCEWithLogitsLoss()(out, masks)
                loss2 = lovasz_softmax(F.softmax(out, dim=1), target)  # tune
                loss = loss1 + loss2

                loss.backward()
                self.optimizer.step()
                gc.collect()

            cum_loss += loss.item() * inputs.size(0)

        epoch_loss = cum_loss / self.train_data.__len__()
        gc.collect()
        return epoch_loss

    def logger_step(self, cur_epoch, losses_train, losses_val, dice):
        """ Log information """
        print(f"[Epoch {cur_epoch}] training loss: {losses_train[-1]:.6f} | val_loss: {losses_val[-1]:.6f} | "
              f"val_dice: {dice:.6f}")
        # print(f"Learning rate: {self.lr_scheduler.get_lr()[0]:.6f}")

        # 1. Log scalar values (scalar summary)
        info = {'loss': losses_train[-1],
                'val_loss': losses_val[-1],
                'dice': dice}

        for tag, value in info.items():
            self.logger.scalar_summary(tag, value, cur_epoch + 1)

        # 2. Log values and gradients of the parameters (histogram summary)
        for tag, value in self.model.named_parameters():
            tag = tag.replace('.', '/')
            self.logger.histo_summary(tag, value.data.cpu().numpy(), cur_epoch + 1)
            self.logger.histo_summary(tag + '/grad', value.grad.data.cpu().numpy(), cur_epoch + 1)

        return True

    def main(self):
        """ Main training loop """
        # Get Model
        self.model = smp.Unet(args.model, classes=4, encoder_weights='imagenet')
        self.model = torch.nn.Sequential(*(list(self.model.children())[:-1]))
        self.model.to(device)
        self.model.load_state_dict(torch.load('output/weights/resnet34_f0_s3.pth'))
        scheduler_step = args.epoch // args.snapshot

        num_train = len(os.listdir('input/severstal-steel-defect-detection/train_images'))
        # num_train = 1000
        indices = list(range(num_train))

        if args.num_fold > 1:
            kf = KFold(n_splits=args.num_fold, random_state=42, shuffle=True)
            train_idx = []
            valid_idx = []
            for t, v in kf.split(indices):
                train_idx.append(t)
                valid_idx.append(v)
        elif args.num_fold == 1:
            train_idx, valid_idx, _, _ = train_test_split(indices, indices, test_size=0.2, random_state=42)
            train_idx, valid_idx = [train_idx], [valid_idx]
        else:
            raise ValueError('Invalid value for args.num_fold')

        for fold in range(args.num_fold):
            print(f'************************\n'
                  f'**** [FOLD: {fold}] ****\n'
                  f'************************')
            self.train_data = getDatabase(mode='train', image_idx=train_idx[fold])
            self.train_loader = DataLoader(self.train_data,
                                           sampler=RandomSampler(self.train_data),
                                           batch_size=args.batch_size,
                                           num_workers=6,
                                           pin_memory=True)
            self.val_data = getDatabase(mode='val', image_idx=valid_idx[fold])
            self.val_loader = DataLoader(self.val_data,
                                         shuffle=False,
                                         batch_size=args.batch_size,
                                         num_workers=6,
                                         pin_memory=True)

            num_snapshot = 0
            best_acc = 0
            # Setup optimizer
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=args.max_lr, momentum=args.momentum,
                                             weight_decay=args.weight_decay)
            # self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer,
            #                                                                scheduler_step, args.min_lr)
            self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=6,
                                                                           verbose=True,
                                                                           )

            # Service variables
            losses_train = []  # save training losses
            losses_val = []  # save validation losses

            for epoch in range(args.epoch):
                train_loss = self.train_step()
                # train_loss = 1
                val_loss, accuracy = self.val_step()
                # self.lr_scheduler.step()  # for CosineAnnealingLR
                self.lr_scheduler.step(val_loss)  # for ReduceLROnPlateau

                losses_train.append(train_loss)
                losses_val.append(val_loss)

                self.logger_step(epoch, losses_train, losses_val, accuracy)

                # scheduler checkpoint
                if accuracy >= best_acc:
                    best_acc = accuracy
                    best_param = self.model.state_dict()
                    torch.save(best_param, args.save_weight + args.weight_name +
                               '_lrPlateau' + '.pth')

                if (epoch + 1) % scheduler_step == 0:
                    torch.save(best_param, args.save_weight + args.weight_name +
                               '_f' + str(fold) + '_s' + str(num_snapshot) + '.pth')

                    self.optimizer = torch.optim.SGD(self.model.parameters(),
                                                     lr=args.max_lr,
                                                     momentum=args.momentum,
                                                     weight_decay=args.weight_decay)
                    self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer,
                                                                                   scheduler_step,
                                                                                   args.min_lr)
                    num_snapshot += 1
                    best_acc = 0

        return True
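
main() reads a module-level args namespace (as well as device, smp, getDatabase and the loss/metric helpers imported elsewhere in the script). The argparse block below is a hypothetical reconstruction that only mirrors the attributes accessed above; every default value is an assumption, not the project's actual configuration.

# Hypothetical CLI setup; defaults are placeholders, not the project's values.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='resnet34')            # encoder name for smp.Unet
parser.add_argument('--epoch', type=int, default=30)          # epochs per fold
parser.add_argument('--snapshot', type=int, default=3)        # cosine-annealing snapshots per fold
parser.add_argument('--num_fold', type=int, default=1)        # 1 = single split, >1 = KFold
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--max_lr', type=float, default=1e-2)
parser.add_argument('--min_lr', type=float, default=1e-3)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--save_weight', default='output/weights/')
parser.add_argument('--weight_name', default='resnet34')
args = parser.parse_args()

if __name__ == '__main__':
    TrainModel().main()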