Exemple #1
0
class DDPGAgent(object):
    """
    General class for DDPG agents
    (policy, critic, target policy, target critic, exploration noise)
    """
    def __init__(self,
                 id,
                 num_in_pol,
                 num_out_pol,
                 num_head_pol,
                 num_in_critic,
                 hidden_dim,
                 lr,
                 lr_critic_coef,
                 use_discrete_action,
                 weight_decay,
                 discrete_exploration_scheme,
                 boltzmann_temperature,
                 logger=None):
        """
        Inputs:
            id (int): index of this agent (selects its head of the policy output)
            num_in_pol (int): number of dimensions for policy input
            num_out_pol (int): number of dimensions for policy output
            num_head_pol (int): number of policy heads (one per teammate for TeamMADDPG)
            num_in_critic (int): number of dimensions for critic input
            hidden_dim (int): hidden layer width for all networks
            lr (float): learning rate for the policy optimizer
            lr_critic_coef (float): critic lr = lr_critic_coef * lr
            use_discrete_action (bool): discrete (one-hot) vs continuous actions
            weight_decay (float): L2 penalty for the critic optimizer
            discrete_exploration_scheme (str): 'e-greedy' or 'boltzmann'
            boltzmann_temperature (float): temperature for boltzmann exploration
            logger: optional logger passed through to the networks
        """
        self.id = id

        # Instantiate the models
        self.policy = MLPNetwork(input_dim=num_in_pol,
                                 out_dim=num_out_pol * num_head_pol,
                                 hidden_dim=hidden_dim,
                                 out_fn='tanh',
                                 use_discrete_action=use_discrete_action,
                                 name="policy",
                                 logger=logger)
        self.critic = MLPNetwork(num_in_critic,
                                 1,
                                 hidden_dim=hidden_dim,
                                 out_fn='linear',
                                 use_discrete_action=use_discrete_action,
                                 name="critic",
                                 logger=logger)

        # Target networks never need gradients of their own construction
        with torch.no_grad():
            self.target_policy = MLPNetwork(
                input_dim=num_in_pol,
                out_dim=num_out_pol * num_head_pol,
                hidden_dim=hidden_dim,
                out_fn='tanh',
                use_discrete_action=use_discrete_action,
                name="target_policy",
                logger=logger)
            self.target_critic = MLPNetwork(
                num_in_critic,
                1,
                hidden_dim=hidden_dim,
                out_fn='linear',
                use_discrete_action=use_discrete_action,
                name="target_critic",
                logger=logger)

        # Start targets as exact copies of the online networks
        hard_update(self.target_policy, self.policy)
        hard_update(self.target_critic, self.critic)

        # Instantiate the optimizers
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)
        self.critic_optimizer = Adam(self.critic.parameters(),
                                     lr=lr_critic_coef * lr,
                                     weight_decay=weight_decay)

        # Sets noise
        if not use_discrete_action:
            self.exploration = OUNoise(num_out_pol)
        else:
            self.exploration = None  # epsilon for eps-greedy
        self.use_discrete_action = use_discrete_action
        self.discrete_exploration_scheme = discrete_exploration_scheme
        self.boltzmann_temperature = boltzmann_temperature

        # Number of heads to the policy (to allow predicting actions of teammates for TeamMADDPG)
        self.num_out_pol = num_out_pol
        self.num_head_pol = num_head_pol

    def reset_noise(self):
        """Reset the OU noise process (no-op for discrete actions)."""
        if not self.use_discrete_action:
            self.exploration.reset()

    def scale_noise(self, scale):
        """Set exploration magnitude: epsilon (discrete) or OU scale (continuous)."""
        if self.use_discrete_action:
            self.exploration = scale
        else:
            self.exploration.scale = scale

    def select_action(self, obs, is_exploring=False):
        """
        Take a step forward in environment for a minibatch of observations
        Inputs:
            obs (PyTorch Variable): Observations for this agent
            is_exploring (boolean): Whether or not to add exploration noise
        Outputs:
            action (PyTorch Variable): Actions for this agent
        """

        raw_action = self.policy(obs)  # shape is (batch, n_agents*act_dim)
        # Keep only this agent's head of the multi-head policy output
        action = raw_action.view(-1, self.num_out_pol,
                                 self.num_head_pol)[:, :, self.id]
        if self.use_discrete_action:
            if is_exploring:
                if self.discrete_exploration_scheme == 'e-greedy':
                    action = onehot_from_logits(action, eps=self.exploration)
                elif self.discrete_exploration_scheme == 'boltzmann':
                    action = gumbel_softmax(action /
                                            self.boltzmann_temperature,
                                            hard=True)
                else:
                    raise NotImplementedError
            else:
                action = onehot_from_logits(action, eps=0.)
            # BUG FIX: final_action was only assigned on the continuous branch,
            # so the discrete path raised UnboundLocalError at `return`.
            final_action = action
        else:  # continuous action
            if is_exploring:
                action += Variable(Tensor(self.exploration.noise()),
                                   requires_grad=False)
            action = action.clamp(-1., 1.)
            final_action = action  # we must remove the squeeze because we consider the batch dim as the env dim further
            # down the code even if the batch dim is one (only one env, unlike previously)
        return final_action

    def get_params(self):
        """Return a dict of all network and optimizer state dicts (for checkpointing)."""
        return {
            'policy': self.policy.state_dict(),
            'critic': self.critic.state_dict(),
            'target_policy': self.target_policy.state_dict(),
            'target_critic': self.target_critic.state_dict(),
            'policy_optimizer': self.policy_optimizer.state_dict(),
            'critic_optimizer': self.critic_optimizer.state_dict()
        }

    def load_params(self, params):
        """Restore all network and optimizer state dicts from `get_params()` output."""
        self.policy.load_state_dict(params['policy'])
        self.critic.load_state_dict(params['critic'])
        self.target_policy.load_state_dict(params['target_policy'])
        self.target_critic.load_state_dict(params['target_critic'])
        self.policy_optimizer.load_state_dict(params['policy_optimizer'])
        self.critic_optimizer.load_state_dict(params['critic_optimizer'])
Exemple #2
0
class Session:
    """Bundles the SPANet derain model, losses, optimizer, logging and
    checkpointing for a single training/evaluation run."""

    def __init__(self):
        self.device = torch.device("cuda")

        self.log_dir = './logdir'
        self.model_dir = './model'
        ensure_dir(self.log_dir)
        ensure_dir(self.model_dir)
        self.log_name = 'train_derain'
        self.val_log_name = 'val_derain'
        logger.info('set log dir as %s' % self.log_dir)
        logger.info('set model dir as %s' % self.model_dir)

        self.test_data_path = 'testing/real_test1000.txt'  # test dataset txt file path
        self.train_data_path = 'training/dataset_small_rand.txt'  # train dataset txt file path

        self.multi_gpu = True
        self.net = SPANet().to(self.device)
        self.l1 = nn.L1Loss().to(self.device)
        self.l2 = nn.MSELoss().to(self.device)
        self.ssim = SSIM().to(self.device)

        self.step = 0
        self.save_steps = 400
        self.num_workers = 16
        self.batch_size = 32
        self.writers = {}
        self.dataloaders = {}
        self.shuffle = True
        self.opt = Adam(self.net.parameters(), lr=5e-3)
        self.sche = MultiStepLR(self.opt,
                                milestones=[5000, 15000, 30000, 50000],
                                gamma=0.1)

    def tensorboard(self, name):
        """Create (and cache) a SummaryWriter named `name` under the log dir."""
        self.writers[name] = SummaryWriter(
            os.path.join(self.log_dir, name + '.events'))
        return self.writers[name]

    def write(self, name, out):
        """Log every scalar in `out` to tensorboard and to the text logger."""
        for k, v in out.items():
            self.writers[name].add_scalar(k, v, self.step)

        out['lr'] = self.opt.param_groups[0]['lr']
        out['step'] = self.step
        outputs = ["{}:{:.4g}".format(k, v) for k, v in out.items()]
        logger.info(name + '--' + ' '.join(outputs))

    def get_dataloader(self, dataset_name, train_mode=True):
        """Build a DataLoader for `dataset_name`; returns an iterator in
        train mode, the loader itself otherwise."""
        dataset = {
            True: TrainValDataset,
            False: TestDataset,
        }[train_mode](dataset_name)
        self.dataloaders[dataset_name] = \
                    DataLoader(dataset, batch_size=self.batch_size,
                            shuffle=self.shuffle, num_workers=self.num_workers, drop_last=True)
        if train_mode:
            return iter(self.dataloaders[dataset_name])
        else:
            return self.dataloaders[dataset_name]

    def save_checkpoints(self, name):
        """Save net weights, optimizer state and the current step to model_dir/name."""
        ckp_path = os.path.join(self.model_dir, name)
        obj = {
            'net': self.net.state_dict(),
            'clock': self.step,
            'opt': self.opt.state_dict(),
        }
        torch.save(obj, ckp_path)

    def load_checkpoints(self, name, mode='train'):
        """Restore a checkpoint; missing files are silently ignored (fresh run)."""
        ckp_path = os.path.join(self.model_dir, name)
        try:
            obj = torch.load(ckp_path)
        except FileNotFoundError:
            return
        self.net.load_state_dict(obj['net'])
        if mode == 'train':
            self.opt.load_state_dict(obj['opt'])
            self.step = obj['clock']
        self.sche.last_epoch = self.step

    def inf_batch(self, name, batch):
        """Run one batch through the net; returns predictions for 'test',
        otherwise the individual and combined training losses."""
        O, B, M = batch['O'], batch['B'], batch['M']
        O, B, M = O.to(self.device), B.to(self.device), M.to(self.device)

        if name == 'test':
            # BUG FIX: the original called torch.set_grad_enabled(False) here
            # and never restored it, silently disabling autograd for every
            # subsequent training batch. A scoped no_grad() context restores
            # the previous grad mode on exit.
            with torch.no_grad():
                mask, out = self.net(O)
            return out.cpu().data, batch['B'], O, mask

        mask, out = self.net(O)

        # loss
        l1_loss = self.l1(out, B)
        mask_loss = self.l2(mask[:, 0, :, :], M)
        ssim_loss = self.ssim(out, B)

        loss = l1_loss + (1 - ssim_loss) + mask_loss

        # log
        losses = {'l1_loss': l1_loss.item()}
        l2 = {'mask_loss': mask_loss.item()}
        losses.update(l2)
        ssimes = {'ssim_loss': ssim_loss.item()}
        losses.update(ssimes)
        allloss = {'all_loss': loss.item()}
        losses.update(allloss)
        return out, mask, M, loss, losses

    def heatmap(self, img):
        """Colorize a (b,h,w) or (b,c,h,w) uint8 map to (b,3,h,w) JET heatmaps."""
        if len(img.shape) == 3:
            b, h, w = img.shape
            heat = np.zeros((b, 3, h, w)).astype('uint8')
            for i in range(b):
                heat[i, :, :, :] = np.transpose(
                    cv2.applyColorMap(img[i, :, :], cv2.COLORMAP_JET),
                    (2, 0, 1))
        else:
            b, c, h, w = img.shape
            heat = np.zeros((b, 3, h, w)).astype('uint8')
            for i in range(b):
                heat[i, :, :, :] = np.transpose(
                    cv2.applyColorMap(img[i, 0, :, :], cv2.COLORMAP_JET),
                    (2, 0, 1))
        return heat

    def save_mask(self, name, img_lists, m=0):
        """Write a side-by-side grid image (input|pred|label|mask|mask_label)
        to the log dir for visual inspection."""
        data, pred, label, mask, mask_label = img_lists
        pred = pred.cpu().data

        mask = mask.cpu().data
        mask_label = mask_label.cpu().data
        data, label, pred, mask, mask_label = data * 255, label * 255, pred * 255, mask * 255, mask_label * 255
        pred = np.clip(pred, 0, 255)

        mask = np.clip(mask.numpy(), 0, 255).astype('uint8')
        mask_label = np.clip(mask_label.numpy(), 0, 255).astype('uint8')
        h, w = pred.shape[-2:]
        mask = self.heatmap(mask)
        mask_label = self.heatmap(mask_label)
        gen_num = (1, 1)

        img = np.zeros((gen_num[0] * h, gen_num[1] * 5 * w, 3))
        # BUG FIX: the original wrapped this in `for img_list in img_lists:`
        # which redrew the identical grid five times (the loop variable was
        # never used); one pass produces the same image.
        for i in range(gen_num[0]):
            row = i * h
            for j in range(gen_num[1]):
                idx = i * gen_num[1] + j
                tmp_list = [
                    data[idx], pred[idx], label[idx], mask[idx],
                    mask_label[idx]
                ]
                for k in range(5):
                    col = (j * 5 + k) * w
                    tmp = np.transpose(tmp_list[k], (1, 2, 0))
                    img[row:row + h, col:col + w] = tmp

        img_file = os.path.join(self.log_dir, '%d_%s.png' % (self.step, name))
        cv2.imwrite(img_file, img)
Exemple #3
0
class BaseLearner():
    """Base actor-critic learner: owns the actor (and optionally critic and
    target networks), their optimizers, and (de)serialization helpers."""

    def __init__(self, args, logger):
        self.logger = logger
        self.args = args

        self.actor = AGENT_REGISTRY[args.actor](args)
        self.critic = None  # subclasses may create a critic

        self.target_actor = None
        self.target_critic = None

        self.actor_optimizer = Adam(params=self.actor.parameters(), lr=args.lr)
        self.critic_optimizer = None

        self.last_log = -self.args.log_interval - 1  # log the first run

    def forward(self, s):
        """Run the actor on a single (flattened) state `s`."""
        s = th.tensor(s, dtype=th.float).view(1, -1).to(self.args.device)
        actor_out = self.actor(s)
        return actor_out

    def _soft_update(self, net, target_net):
        """Polyak update: target <- tau * net + (1 - tau) * target.

        Factored out of _update_target_critic/_update_target_actor which
        previously duplicated this loop verbatim.
        """
        for param, target_param in zip(net.parameters(),
                                       target_net.parameters()):
            target_param.data.copy_(self.args.tau * param.data +
                                    (1 - self.args.tau) * target_param.data)

    def _update_target_critic(self):
        self._soft_update(self.critic, self.target_critic)

    def _update_target_actor(self):
        self._soft_update(self.actor, self.target_actor)

    def cuda(self):
        """Move all existing networks to the GPU."""
        self.actor.cuda()
        if self.critic is not None:
            self.critic.cuda()
        if self.target_critic is not None:
            self.target_critic.cuda()
        if self.target_actor is not None:
            self.target_actor.cuda()

    def save_models(self, path):
        """Save actor/critic, their targets, and optimizers under `path`."""
        # save actor and critic
        th.save(self.actor.state_dict(), "{}/actor.th".format(path))
        if self.critic is not None:
            th.save(self.critic.state_dict(), "{}/critic.th".format(path))
        # save target networks
        if self.target_actor is not None:
            th.save(self.target_actor.state_dict(),
                    "{}/tar_actor.th".format(path))
        if self.target_critic is not None:
            th.save(self.target_critic.state_dict(),
                    "{}/tar_critic.th".format(path))
        # save optimizers
        th.save(self.actor_optimizer.state_dict(),
                "{}/actor_opt.th".format(path))
        if self.critic_optimizer is not None:
            th.save(self.critic_optimizer.state_dict(),
                    "{}/critic_opt.th".format(path))

    def load_models(self, path):
        """Load everything saved by save_models; CPU-maps all tensors."""
        # actor & critic
        self.actor.load_state_dict(
            th.load("{}/actor.th".format(path),
                    map_location=lambda storage, loc: storage))
        if self.critic is not None:
            self.critic.load_state_dict(
                th.load("{}/critic.th".format(path),
                        map_location=lambda storage, loc: storage))
        # target networks
        if self.target_critic is not None:
            self.target_critic.load_state_dict(
                th.load("{}/tar_critic.th".format(path),
                        map_location=lambda storage, loc: storage))
        if self.target_actor is not None:
            self.target_actor.load_state_dict(
                th.load("{}/tar_actor.th".format(path),
                        map_location=lambda storage, loc: storage))
        # optimizers
        if self.actor_optimizer is not None:
            self.actor_optimizer.load_state_dict(
                th.load("{}/actor_opt.th".format(path),
                        map_location=lambda storage, loc: storage))
        if self.critic_optimizer is not None:
            self.critic_optimizer.load_state_dict(
                th.load("{}/critic_opt.th".format(path),
                        map_location=lambda storage, loc: storage))
class Downstream_Solver(Solver):
    ''' Handler for complete training progress'''
    def __init__(self, config, paras, task):
        super(Downstream_Solver, self).__init__(config, paras)

        # backup upstream settings
        self.upstream_paras = copy.deepcopy(paras)
        self.upstream_config = copy.deepcopy(config)
        self.task = task  # Downstream task the solver is solving

        # path and directories (swap the 'transformer' placeholder for the task name)
        self.exp_name = self.exp_name.replace('transformer', task)
        self.logdir = self.paras.logdir.replace('transformer', task)
        self.ckpdir = self.ckpdir.replace('transformer', task)
        self.expdir = self.expdir.replace('transformer', task)
        self.dckpt = os.path.join(self.ckpdir, self.paras.dckpt)

        # model
        self.model_type = config['downstream']['model_type']
        self.load_model_list = config['downstream']['load_model_list']
        self.fine_tune = self.paras.fine_tune
        self.run_transformer = self.paras.run_transformer
        self.run_apc = self.paras.run_apc
        # Fine-tuning is only defined for the head-less transformer upstream
        if self.fine_tune:
            assert (
                self.run_transformer
            ), 'Use `--run_transformer` to fine-tune the transformer model.'
            assert (not self.run_apc
                    ), 'Fine tuning only supports the transformer model.'
            assert (
                not self.paras.with_head
            ), 'Fine tuning only supports the transformer model, not with head.'
        assert (not (self.run_transformer and self.run_apc)
                ), 'Transformer and Apc can not run at the same time!'
        if self.run_transformer and self.paras.with_head:
            self.verbose('Using transformer speech representations from head.')
        elif self.run_transformer and self.fine_tune:
            self.verbose('Fine-tuning on transformer speech representations.')
        elif self.run_transformer:
            self.verbose('Using transformer speech representations.')

    def load_data(self, split='train', load='phone'):
        ''' Load date for training / testing'''
        assert (load in [
            'phone', 'cpc_phone', 'sentiment', 'speaker', 'speaker_large'
        ]), 'Unsupported dataloader!'
        if load == 'phone' or load == 'cpc_phone' or load == 'speaker_large':
            if split == 'train':
                self.verbose('Loading source data from ' +
                             str(self.config['dataloader']['train_set']) +
                             ' from ' + self.config['dataloader']['data_path'])
                if load == 'phone' or load == 'cpc_phone':
                    self.verbose('Loading phone data from ' +
                                 str(self.config['dataloader']['train_set']) +
                                 ' from ' +
                                 self.config['dataloader']['phone_path'])
            elif split == 'test':
                if load != 'cpc_phone':
                    self.verbose('Loading testing data ' +
                                 str(self.config['dataloader']['test_set']) +
                                 ' from ' +
                                 self.config['dataloader']['data_path'])
                if load == 'phone':
                    self.verbose('Loading label data ' +
                                 str(self.config['dataloader']['test_set']) +
                                 ' from ' +
                                 self.config['dataloader']['phone_path'])
                elif load == 'cpc_phone':
                    self.verbose('Loading label data from ' +
                                 self.config['dataloader']['phone_path'])
            else:
                raise NotImplementedError('Invalid `split` argument!')
        elif load == 'speaker':
            # the speaker task uses the 100h subset in place of the 360h one
            if split == 'train':
                self.verbose('Loading source data from ' +
                             str(self.config['dataloader']
                                 ['train_set']).replace('360', '100') +
                             ' from ' + self.config['dataloader']['data_path'])
            elif split == 'test':
                self.verbose('Loading testing data ' +
                             str(self.config['dataloader']
                                 ['test_set']).replace('360', '100') +
                             ' from ' + self.config['dataloader']['data_path'])
            else:
                raise NotImplementedError('Invalid `split` argument!')
        elif load == 'sentiment':
            target = self.config['dataloader']['sentiment_config']['dataset']
            sentiment_path = self.config['dataloader']['sentiment_config'][
                target]['path']
            self.verbose(f'Loading {split} data from {sentiment_path}')
        else:
            raise NotImplementedError('Unsupported downstream tasks.')

        setattr(self, 'dataloader', get_Dataloader(split, load=load, use_gpu=self.paras.gpu, \
                run_mam=self.run_transformer, mam_config=self.transformer_config, \
                **self.config['dataloader']))

    def set_model(self, inference=False):
        """Build the upstream feature extractor, the downstream classifier,
        and (when training) the optimizer."""
        input_dim = int(self.config['downstream'][self.model_type]['input_dim']) if \
                    self.config['downstream'][self.model_type]['input_dim'] != 'None' else None
        if 'transformer' in self.task:
            self.upstream_tester = Tester(self.upstream_config,
                                          self.upstream_paras)
            if self.fine_tune and inference:
                self.upstream_tester.load = False  # During inference on fine-tuned model, load with `load_downstream_model()`
            self.upstream_tester.set_model(
                inference=True, with_head=self.paras.with_head
            )  # inference should be set True so upstream solver won't create optimizer
            self.dr = self.upstream_tester.dr
            if input_dim is None:
                input_dim = self.transformer_config['hidden_size']
        elif 'apc' in self.task:
            self.apc = get_apc_model(path=self.paras.apc_path)
            if input_dim is None:
                input_dim = self.transformer_config[
                    'hidden_size']  # use identical dim size for fair comparison
        elif 'baseline' in self.task:
            if input_dim is None:
                if 'input_dim' in self.transformer_config:
                    input_dim = self.transformer_config['input_dim']
                else:
                    raise ValueError(
                        'Please update your config file to include the attribute `input_dim`.'
                    )
        else:
            raise NotImplementedError('Invalid Task!')

        if self.model_type == 'linear':
            self.classifier = LinearClassifier(
                input_dim=input_dim,
                class_num=self.dataloader.dataset.class_num,
                dconfig=self.config['downstream']['linear']).to(self.device)
        elif self.model_type == 'rnn':
            self.classifier = RnnClassifier(
                input_dim=input_dim,
                class_num=self.dataloader.dataset.class_num,
                dconfig=self.config['downstream']['rnn']).to(self.device)
        else:
            # BUG FIX: previously an unknown model_type silently left
            # self.classifier undefined and crashed later with AttributeError.
            raise NotImplementedError('Invalid `model_type`!')

        if not inference and self.fine_tune:
            # Setup Fine tune optimizer (trains transformer + classifier jointly)
            self.upstream_tester.transformer.train()
            param_optimizer = list(
                self.upstream_tester.transformer.named_parameters()) + list(
                    self.classifier.named_parameters())
            self.optimizer = get_optimizer(
                params=param_optimizer,
                lr=self.learning_rate,
                warmup_proportion=self.config['optimizer']
                ['warmup_proportion'],
                training_steps=self.total_steps)
        elif not inference:
            self.optimizer = Adam(self.classifier.parameters(),
                                  lr=self.learning_rate,
                                  betas=(0.9, 0.999))
            self.classifier.train()
        else:
            self.classifier.eval()

        if self.load:  # This will be set to True by default when Tester is running set_model()
            self.load_downstream_model(inference=inference)

    def save_model(self, name, model_all=True, assign_name=None):
        """Checkpoint the classifier (and optionally optimizer/transformer),
        rotating out the oldest file once `max_keep` is exceeded."""
        if model_all:
            all_states = {
                'Classifier':
                self.classifier.state_dict(),
                'Transformer':
                self.upstream_tester.transformer.state_dict()
                if self.fine_tune else None,
                'Optimizer':
                self.optimizer.state_dict(),
                'Global_step':
                self.global_step,
                'Settings': {
                    'Config': self.config,
                    'Paras': self.paras,
                },
            }
        else:
            all_states = {
                'Classifier': self.classifier.state_dict(),
                'Settings': {
                    'Config': self.config,
                    'Paras': self.paras,
                },
            }

        if assign_name is not None:
            model_path = f'{self.expdir}/{assign_name}.ckpt'
            torch.save(all_states, model_path)
            return

        new_model_path = '{}/{}-{}.ckpt'.format(self.expdir, name,
                                                self.global_step)
        torch.save(all_states, new_model_path)
        self.model_kept.append(new_model_path)

        if len(self.model_kept) >= self.max_keep:
            os.remove(self.model_kept[0])
            self.model_kept.pop(0)

    def load_downstream_model(self, inference=False):
        """Best-effort restore of each component listed in load_model_list;
        each piece logs '- Loaded' or '- X' instead of aborting the run."""
        self.verbose('Load model from {}'.format(self.dckpt))
        all_states = torch.load(self.dckpt, map_location='cpu')

        if 'Classifier' in self.load_model_list:
            try:
                self.classifier.load_state_dict(all_states['Classifier'])
                self.verbose('[Classifier] - Loaded')
            # BUG FIX: bare `except:` also swallowed KeyboardInterrupt and
            # SystemExit; narrowed to Exception (same for the clauses below).
            except Exception:
                self.verbose('[Classifier - X]')

        if 'Optimizer' in self.load_model_list and not inference:
            try:
                self.optimizer.load_state_dict(all_states['Optimizer'])
                # move optimizer state tensors back onto the GPU
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.cuda()
                self.verbose('[Optimizer] - Loaded')
            except Exception:
                self.verbose('[Optimizer - X]')

        if 'Global_step' in self.load_model_list:
            try:
                self.global_step = all_states['Global_step']
                self.verbose('[Global_step] - Loaded')
            except Exception:
                self.verbose('[Global_step - X]')

        if self.fine_tune:
            try:
                self.verbose(
                    '@ Downstream, [Fine-Tuned Transformer] - Loading with Upstream Tester...'
                )
                self.upstream_tester.load_model(inference=inference,
                                                from_path=self.ckpt)
                self.verbose('@ Downstream, [Fine-Tuned Transformer] - Loaded')
            except Exception:
                self.verbose('[Fine-Tuned Transformer] - X')

        self.verbose('Model loading complete!')
Exemple #5
0
class Train(Procedure):
    """Training procedure: sets up directories and summaries, then runs the
    batch loop with periodic printing, checkpointing and optional eval."""

    def __init__(self, params, model_file_path=None):
        super().__init__(params, is_eval=False)
        # wait for creating threads
        time.sleep(10)
        cur_time = int(time.time())
        if model_file_path is None:
            train_dir = os.path.join(self.params.model_root,
                                     'train_%d' % (cur_time))
        else:
            # model_file_path is expected to be train_dir/model/model_name
            train_dir = os.path.dirname(
                os.path.dirname(os.path.abspath(model_file_path)))
        if not os.path.exists(train_dir):
            os.makedirs(train_dir)
        self.model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        # dump the params
        param_path = os.path.join(train_dir, 'params_{}.json'.format(cur_time))
        print("Dump hyper-parameters to {}.".format(param_path))
        params.save(param_path)

        self.model_file_path = model_file_path
        self.summary_writer = tf.summary.FileWriter(train_dir)
        self.summary_flush_interval = self.params.summary_flush_interval
        self.print_interval = self.params.print_interval
        self.save_interval = self.params.save_interval

    def _get_save_path(self, iter):
        """Build (model, params) save paths encoding coverage/RL mode,
        iteration and timestamp."""
        cur_time = time.time()
        if self.params.is_coverage:
            prefix = 'coverage_model_{}_{}'
            param_prefix = 'coverage_params_{}_{}'
        else:
            prefix = 'model_{}_{}'
            param_prefix = 'params_{}_{}'

        if self.params.train_rl:
            prefix = 'rl_' + prefix
            param_prefix = 'rl_' + param_prefix
        model_save_path = os.path.join(self.model_dir,
                                       prefix.format(iter, cur_time))
        param_save_path = os.path.join(self.model_dir,
                                       param_prefix.format(iter, cur_time))
        return model_save_path, param_save_path

    def save_model(self, iter, running_avg_loss, model_save_path):
        """Serialize encoder/decoder weights, optimizer state and progress."""
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path):
        """Create model and optimizer, optionally resuming from a checkpoint.
        Returns (start_iter, start_loss)."""
        # check params
        rl_weight = getattr(self.params, "rl_weight", 0.0)
        if self.params.train_rl and rl_weight == 0.0:
            raise ValueError(
                "Train RL is True, while rl_weight is 0.0. Contradiction!")

        self.model = PointerEncoderDecoder(self.params,
                                           model_file_path,
                                           pad_id=self.pad_id)
        initial_lr = self.params.lr if not self.params.is_coverage else self.params.lr_coverage
        optim_name = self.params.optim.lower()
        if optim_name == "adam":
            # NOTE(review): `parameters` is passed uncalled — presumably
            # PointerEncoderDecoder exposes a parameter list attribute/property;
            # confirm (a plain nn.Module would need `parameters()`).
            self.optimizer = Adam(self.model.parameters, lr=initial_lr)
        else:
            raise ValueError("Unknown optim {}".format(optim_name))

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            train_rl = self.params.train_rl
            reoptim = self.params.reoptim
            if not train_rl and reoptim:
                raise ValueError("Not training rl but recreate the optimizer")

            # We need not to load the checkpoint if we use coverage to retrain
            if not self.params.is_coverage and not reoptim:
                print("Load the optimizer...")
                sys.stdout.flush()
                self.optimizer.load_state_dict(state['optimizer'])
                if utils.use_cuda(self.params.device):
                    device = torch.device(self.params.device)
                    # BUG FIX: the loop variable was named `state`, shadowing
                    # the checkpoint dict loaded above.
                    for opt_state in self.optimizer.state.values():
                        for k, v in opt_state.items():
                            if torch.is_tensor(v):
                                opt_state[k] = v.to(device)

        return start_iter, start_loss

    def train_one_batch(self, batch, iter):
        """Run one training step; returns (loss, reward)."""
        return self.infer_one_batch(batch, iter, is_eval=False)

    def train(self, n_iters=None, eval=False):
        """
        :param n_iters: the iterations of training process
        :param eval: whether to spawn a decode process at every checkpoint
        :return:
            do not return anything, but will print logs and store models
        """
        eval_processes = []
        if n_iters is None:  # idiom fix: identity test for None, not ==
            n_iters = self.params.max_iterations
        iter, running_avg_loss = self.setup_train(self.model_file_path)
        start_iter = iter
        total_iter = n_iters - start_iter
        start = time.time()
        start_time = start
        print("start training.")
        sys.stdout.flush()
        loss_total = 0
        reward_total = 0

        while iter < n_iters:
            batch = self.batcher.next_batch()
            loss, reward = self.train_one_batch(batch, iter)

            running_avg_loss = utils.calc_running_avg_loss(
                loss, running_avg_loss, self.summary_writer, iter)
            loss_total += loss
            reward_total += reward
            iter += 1

            if iter % self.summary_flush_interval == 0:
                self.summary_writer.flush()
            if iter % self.print_interval == 0:
                elapse, remain = utils.time_since(
                    start_time, (iter - start_iter) / total_iter)
                iter_num = iter - start_iter
                print(
                    'Train steps %d, seconds for %d batch: %.2f , loss: %f, reward: %f, elapse: %s, remain: %s'
                    % (iter, self.print_interval, time.time() - start,
                       loss_total / iter_num, reward_total / iter_num, elapse,
                       remain))
                sys.stdout.flush()
                start = time.time()
                if np.isnan(loss) or np.isnan(running_avg_loss):
                    raise ValueError("Loss becomes nan")

            if iter % self.save_interval == 0:
                model_save_path, param_save_path = self._get_save_path(iter)
                self.save_model(iter, running_avg_loss, model_save_path)
                self.params.save(param_save_path)

                if eval:
                    kwargs = {
                        "params": self.params,
                        "model_path": model_save_path,
                        "ngram_filter": False,
                        "data_file_prefix": "valid."
                    }
                    # p = mp.Process(target=PRSum.eval_raw, kwargs=kwargs)
                    # decode instead of evaluate
                    p = mp.Process(target=PRSum.decode_raw, kwargs=kwargs)
                    eval_processes.append(p)
                    p.start()

        # wait for any spawned decode processes before finishing
        for cur_p in eval_processes:
            cur_p.join()
        print("end training.")
        F.relu(x)

        # 4、输出层
        out = self.fc2(x)

        return F.log_softmax(out, dim=-1)


# Build the MNIST model and its Adam optimizer.
model = MnistModel()
optimizer = Adam(model.parameters(), lr=0.001)

# Resume from a previous run if a checkpoint exists on disk.
if os.path.exists("./model/model.pkl"):
    # Load the saved model parameters
    model.load_state_dict(torch.load("./model/model.pkl"))
    # Load the optimizer's saved state
    optimizer.load_state_dict(torch.load("./model/optimizer.pkl"))


def train(epoch):
    """实现训练过程"""
    mode = True
    model.train(mode=mode)  # 模型设置为训练模式

    data_loader = get_dataloader()
    for idx, (input, target) in enumerate(data_loader):  # 每一轮里面的数据进行遍历
        optimizer.zero_grad()  # 梯度清零
        output = model(input)  # 调用模型,得到预测值
        loss = F.nll_loss(output, target)  # 得到损失
        loss.backward()  # 反向传播
        optimizer.step()  # 梯度更新
        if idx % 10 == 0:
Exemple #7
0
def main(args):
    """Train and evaluate a MotionModel for action recognition.

    Builds the train/val MotionDatasets and loaders, constructs the model
    and optimizer, optionally resumes from a checkpoint (otherwise creates
    a parameter-stamped run directory), then runs the epoch loop, saving a
    checkpoint each epoch and tracking the best micro-AP.
    """
    # Use CUDA?
    args.cuda = args.cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # Load datasets and build data loaders
    # The val set is loaded first so its action set can constrain the train set.
    val_dataset = MotionDataset(args.val_data,
                                fps=args.fps,
                                mapper=args.mapper)
    val_actions = val_dataset.actions.keys()

    train_dataset = MotionDataset(args.train_data,
                                  keep_actions=val_actions,
                                  fps=args.fps,
                                  offset=args.offset,
                                  mapper=args.mapper)
    train_actions = train_dataset.actions.keys()

    # with open('a.txt', 'w') as f1, open('b.txt', 'w') as f2:
    #     f1.write('\n'.join(map(str, train_dataset.actions.keys())))
    #     f2.write('\n'.join(map(str, val_dataset.actions.keys())))

    # Both splits must share the same label space for evaluation to be valid.
    assert len(train_actions) == len(val_actions), \
        "Train and val sets should have same number of actions ({} vs {})".format(
            len(train_actions), len(val_actions))

    in_size, out_size = train_dataset.get_data_size()

    # Class balancing: uniform sampling, or weighted sampling using
    # dataset-provided per-sample weights.
    if args.balance == 'none':
        sampler = RandomSampler(train_dataset)
    else:
        weights = train_dataset.get_weights()
        sampler = WeightedRandomSampler(weights, len(weights))

    # NOTE(review): batch_size=1 — presumably sequences are variable-length;
    # confirm against MotionDataset before batching differently.
    train_loader = DataLoader(train_dataset,
                              batch_size=1,
                              sampler=sampler,
                              num_workers=1,
                              pin_memory=args.cuda)
    val_loader = DataLoader(val_dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1,
                            pin_memory=args.cuda)

    # Build the model
    model = MotionModel(in_size,
                        out_size,
                        hidden=args.hd,
                        dropout=args.dropout,
                        bidirectional=args.bidirectional,
                        stack=args.stack,
                        layers=args.layers,
                        embed=args.embed)
    if args.cuda:
        model.cuda()

    # Create the optimizer and start training-eval loop
    # NOTE(review): an unrecognized args.optim leaves `optimizer` undefined
    # and fails later with NameError — presumably argparse restricts choices.
    if args.optim == 'adam':
        optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    elif args.optim == 'sgd':
        optimizer = SGD(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # Resume training?
    if args.resume:
        # Reuse the existing run directory and restore model/optimizer state.
        run_dir = args.resume
        last_checkpoint = get_last_checkpoint(run_dir)
        checkpoint = torch.load(last_checkpoint)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        best_ap = checkpoint['best_micro_ap']
        start_epoch = checkpoint['epoch'] + 1
    else:
        best_ap = 0
        start_epoch = 1
        parameters = vars(args)

        train_fname = os.path.splitext(os.path.basename(args.train_data))[0]
        val_fname = os.path.splitext(os.path.basename(args.val_data))[0]

        # Create the run directory and log file
        # The run name encodes every hyperparameter so runs are self-describing
        # and duplicate configurations are detected by directory existence.
        run_name = 'segment_tr-{1[train]}_vl-{1[val]}_' \
                   'bi{0[bidirectional]}_' \
                   'emb{0[embed]}_' \
                   'h{0[hd]}_' \
                   's{0[stack]}_' \
                   'l{0[layers]}_' \
                   '{0[head]}_' \
                   'a{0[accumulate]}_' \
                   'c{0[clip_norm]}_' \
                   'd{0[dropout]}_' \
                   'lr{0[lr]}_' \
                   'wd{0[wd]}_' \
                   'e{0[epochs]}_' \
                   'f{0[fps]:g}_' \
                   'o-{0[offset]}_' \
                   'opt-{0[optim]}_' \
                   'ls{0[label_smoothing]}_' \
                   'bal-{0[balance]}'.format(parameters, dict(train=train_fname, val=val_fname))

        runs_parent_dir = 'debug' if args.debug else args.run_dir
        run_dir = os.path.join(runs_parent_dir, run_name)
        if not os.path.exists(run_dir):
            os.makedirs(run_dir)
        elif not args.debug:
            # A non-debug run with this exact configuration already exists:
            # bail out instead of overwriting it.
            return

        params = pd.DataFrame(
            parameters, index=[0])  # an index is mandatory for a single line
        params_fname = os.path.join(run_dir, 'params.csv')
        params.to_csv(params_fname, index=False)

        with pd.option_context('display.width',
                               None), pd.option_context('max_columns', None):
            print(params)

    log_file = os.path.join(run_dir, 'log.txt')
    args.log = open(log_file, 'a+')

    progress_bar = trange(start_epoch,
                          args.epochs + 1,
                          initial=start_epoch,
                          disable=args.no_progress)
    for epoch in progress_bar:
        progress_bar.set_description(
            'TRAIN [CurBestAP={:4.3f}]'.format(best_ap))
        train(train_loader, model, optimizer, epoch, args)

        progress_bar.set_description(
            'EVAL [CurBestAP={:4.3f}]'.format(best_ap))
        metrics = evaluate(val_loader, model, args)

        print('Eval Epoch {}: '
              'Loss={:6.4f} '
              'microAP={:4.3f} '
              'macroAP={:4.3f} '
              'F1={:4.3f} '
              'microMultiF1={:4.3f} '
              'macroMultiF1={:4.3f}'.format(epoch, *metrics),
              file=args.log,
              flush=True)

        # metrics[1] is micro-AP (see the format string above) — the model
        # selection criterion.
        current_micro_ap = metrics[1]
        is_best = current_micro_ap > best_ap
        best_ap = max(best_ap, current_micro_ap)

        # SAVE MODEL
        # --keep retains one checkpoint per epoch; otherwise the last one is
        # overwritten in place.
        if args.keep:
            fname = 'epoch_{:02d}.pth'.format(epoch)
        else:
            fname = 'last_checkpoint.pth'

        fname = os.path.join(run_dir, fname)
        save_checkpoint(
            {
                'epoch': epoch,
                'best_micro_ap': best_ap,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, is_best, fname)
Exemple #8
0
def main(args):
    """Train a WaveRNN vocoder on LJSpeech.

    Builds the mel-spectrogram transform pipeline, the train/val loaders,
    the (optionally TorchScripted, DataParallel) WaveRNN model, optimizer
    and loss, optionally resumes from a checkpoint, then runs the epoch
    loop with periodic validation and best-loss checkpointing.
    """

    devices = ["cuda" if torch.cuda.is_available() else "cpu"]

    logging.info("Start time: {}".format(str(datetime.now())))

    # STFT parameters shared by the spectrogram transform below.
    melkwargs = {
        "n_fft": args.n_fft,
        "power": 1,
        "hop_length": args.hop_length,
        "win_length": args.win_length,
    }

    # Waveform -> magnitude spectrogram -> mel scale -> normalized dB.
    transforms = torch.nn.Sequential(
        torchaudio.transforms.Spectrogram(**melkwargs),
        LinearToMel(
            sample_rate=args.sample_rate,
            n_fft=args.n_fft,
            n_mels=args.n_freq,
            fmin=args.f_min,
        ),
        NormalizeDB(min_level_db=args.min_level_db),
    )

    train_dataset, val_dataset = split_process_ljspeech(args, transforms)

    loader_training_params = {
        "num_workers": args.workers,
        "pin_memory": False,
        "shuffle": True,
        "drop_last": False,
    }
    # Validation reuses the training loader settings except for shuffling.
    loader_validation_params = loader_training_params.copy()
    loader_validation_params["shuffle"] = False

    collate_fn = collate_factory(args)

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        **loader_training_params,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        **loader_validation_params,
    )

    # Cross-entropy predicts one of 2**n_bits quantized sample values; the
    # mixture-of-logistics loss uses a fixed 30-dim parameter output.
    n_classes = 2**args.n_bits if args.loss == "crossentropy" else 30

    model = WaveRNN(
        upsample_scales=args.upsample_scales,
        n_classes=n_classes,
        hop_length=args.hop_length,
        n_res_block=args.n_res_block,
        n_rnn=args.n_rnn,
        n_fc=args.n_fc,
        kernel_size=args.kernel_size,
        n_freq=args.n_freq,
        n_hidden=args.n_hidden_melresnet,
        n_output=args.n_output_melresnet,
    )

    if args.jit:
        model = torch.jit.script(model)

    model = torch.nn.DataParallel(model)
    model = model.to(devices[0], non_blocking=True)

    n = count_parameters(model)
    logging.info(f"Number of parameters: {n}")

    # Optimizer
    optimizer_params = {
        "lr": args.learning_rate,
    }

    optimizer = Adam(model.parameters(), **optimizer_params)

    criterion = LongCrossEntropyLoss(
    ) if args.loss == "crossentropy" else MoLLoss()

    # Initial "best" validation loss; any real validation below this improves it.
    best_loss = 10.0

    if args.checkpoint and os.path.isfile(args.checkpoint):
        logging.info(f"Checkpoint: loading '{args.checkpoint}'")
        checkpoint = torch.load(args.checkpoint)

        args.start_epoch = checkpoint["epoch"]
        best_loss = checkpoint["best_loss"]

        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])

        logging.info(
            f"Checkpoint: loaded '{args.checkpoint}' at epoch {checkpoint['epoch']}"
        )
    else:
        logging.info("Checkpoint: not found")

        # Write an initial checkpoint so the file exists from epoch 0 on.
        # NOTE(review): if args.checkpoint is empty this writes to an empty
        # path — presumably the CLI always supplies one; confirm.
        save_checkpoint(
            {
                "epoch": args.start_epoch,
                "state_dict": model.state_dict(),
                "best_loss": best_loss,
                "optimizer": optimizer.state_dict(),
            },
            False,
            args.checkpoint,
        )

    for epoch in range(args.start_epoch, args.epochs):

        train_one_epoch(
            model,
            criterion,
            optimizer,
            train_loader,
            devices[0],
            epoch,
        )

        # Validate every print_freq epochs and always on the final epoch.
        if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:

            sum_loss = validate(model, criterion, val_loader, devices[0],
                                epoch)

            is_best = sum_loss < best_loss
            best_loss = min(sum_loss, best_loss)
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": model.state_dict(),
                    "best_loss": best_loss,
                    "optimizer": optimizer.state_dict(),
                },
                is_best,
                args.checkpoint,
            )

    logging.info(f"End time: {datetime.now()}")
Exemple #9
0
class DDPGAgent(object):
    """
    General class for DDPG agents (policy, critic, target policy, target
    critic, exploration noise)
    """
    def __init__(self,
                 num_in_pol,
                 num_out_pol,
                 num_in_critic,
                 hidden_dim_actor=120,
                 hidden_dim_critic=64,
                 lr_actor=0.01,
                 lr_critic=0.01,
                 batch_size=64,
                 max_episode_len=100,
                 tau=0.02,
                 gamma=0.99,
                 agent_name='one',
                 discrete_action=False):
        """
        Inputs:
            num_in_pol (int): number of dimensions for policy input
            num_out_pol (int): number of dimensions for policy output
            num_in_critic (int): number of dimensions for critic input
                (NOTE(review): currently unused — the critic is built from
                num_in_pol and num_out_pol; kept for interface compatibility)
        """
        # Online networks.
        self.policy = Actor(num_in_pol, num_out_pol,
                                 hidden_dim=hidden_dim_actor,
                                 discrete_action=discrete_action)
        self.critic = Critic(num_in_pol, 1,num_out_pol,
                                 hidden_dim=hidden_dim_critic)
        # Target networks, initialized as exact copies of the online ones.
        self.target_policy = Actor(num_in_pol, num_out_pol,
                                        hidden_dim=hidden_dim_actor,
                                        discrete_action=discrete_action)
        self.target_critic = Critic(num_in_pol, 1,num_out_pol,
                                        hidden_dim=hidden_dim_critic)
        hard_update(self.target_policy, self.policy)
        hard_update(self.target_critic, self.critic)
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr_actor)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr_critic,weight_decay=0)

        self.policy = self.policy.float()
        self.critic = self.critic.float()
        self.target_policy = self.target_policy.float()
        self.target_critic = self.target_critic.float()

        self.agent_name = agent_name
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        #self.replay_buffer = ReplayBuffer(1e7)
        self.replay_buffer = ReplayBufferOption(500000,self.batch_size,12)
        self.max_replay_buffer_len = batch_size * max_episode_len
        self.replay_sample_index = None
        # Global update counter, used as the x-axis for logged scalars.
        self.niter = 0
        # Exploration-noise scale, linearly annealed in update().
        self.eps = 5.0
        self.eps_decay = 1/(250*5)

        self.exploration = OUNoise(num_out_pol)
        self.discrete_action = discrete_action

        # Transitions are accumulated in short histories of this length
        # before being pushed to the replay buffer (see step()).
        self.num_history = 2
        self.states = []
        self.actions = []
        self.rewards = []
        self.next_states = []
        self.dones = []

    def reset_noise(self):
        """Reset the OU exploration noise process (continuous actions only)."""
        if not self.discrete_action:
            self.exploration.reset()

    def scale_noise(self, scale):
        """Set the exploration noise scale (or epsilon for discrete actions)."""
        if self.discrete_action:
            self.exploration = scale
        else:
            self.exploration.scale = scale

    def act(self, obs, explore=False):
        """
        Take a step forward in environment for a minibatch of observations
        Inputs:
            obs : Observations for this agent
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            action (PyTorch Variable): Actions for this agent
        """
        #obs = obs.reshape(1,48)
        state = Variable(torch.Tensor(obs),requires_grad=False)

        # Evaluate deterministically (no dropout/batchnorm updates, no grad).
        self.policy.eval()
        with torch.no_grad():
            action = self.policy(state)
        self.policy.train()
        # continuous action: add scaled OU noise, then clamp to valid range.
        if explore:
            action += Variable(Tensor(self.eps * self.exploration.sample()),requires_grad=False)
            action = torch.clamp(action, min=-1, max=1)
        return action

    def step(self, agent_id, state, action, reward, next_state, done,t_step):
        """Record a transition and trigger learning when enough samples exist.

        Transitions are buffered locally and flushed to the replay buffer
        every `num_history` environment steps.
        """
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.next_states.append(next_state)
        self.dones.append(done)

        #self.replay_buffer.add(state, action, reward, next_state, done)
        if t_step % self.num_history == 0:
            # Save experience / reward
            self.replay_buffer.add(self.states, self.actions, self.rewards, self.next_states, self.dones)
            self.states = []
            self.actions = []
            self.rewards = []
            self.next_states = []
            self.dones = []

        # Learn, if enough samples are available in memory
        if len(self.replay_buffer) > self.batch_size:
            obs, acs, rews, next_obs, don = self.replay_buffer.sample()
            self.update(agent_id ,obs,  acs, rews, next_obs, don,t_step)

    def update(self, agent_id, obs, acs, rews, next_obs, dones ,t_step, logger=None):
        """One DDPG update: critic TD-regression, then policy gradient step.

        Also soft-updates the target networks and anneals exploration eps.
        """
        obs = torch.from_numpy(obs).float()
        acs = torch.from_numpy(acs).float()
        rews = torch.from_numpy(rews[:,agent_id]).float()
        next_obs = torch.from_numpy(next_obs).float()
        dones = torch.from_numpy(dones[:,agent_id]).float()

        # Flatten buffered action histories into (batch, action_dim) rows.
        acs = acs.view(-1,2)

        # --------- update critic ------------ #
        self.critic_optimizer.zero_grad()

        all_trgt_acs = self.target_policy(next_obs)

        # TD target: r + gamma * Q'(s', mu'(s')) * (1 - done).
        target_value = (rews + self.gamma *
                        self.target_critic(next_obs,all_trgt_acs) *
                        (1 - dones))

        actual_value = self.critic(obs,acs)
        vf_loss = MSELoss(actual_value, target_value.detach())

        # Minimize the loss
        vf_loss.backward()
        #torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 1)
        self.critic_optimizer.step()

        # --------- update actor --------------- #
        self.policy_optimizer.zero_grad()

        if self.discrete_action:
            # Differentiable sample through Gumbel-softmax for discrete actions.
            curr_pol_out = self.policy(obs)
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = self.policy(obs)
            curr_pol_vf_in = curr_pol_out

        # Ascend the critic's estimate of Q(s, mu(s)).
        pol_loss = -self.critic(obs,curr_pol_vf_in).mean()
        #pol_loss += (curr_pol_out**2).mean() * 1e-3
        pol_loss.backward()

        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1)
        self.policy_optimizer.step()

        self.update_all_targets()
        self.eps -= self.eps_decay
        self.eps = max(self.eps, 0)

        if logger is not None:
            # BUGFIX: agent_name is a string (default 'one'); the original
            # '%i' format raised TypeError whenever a logger was supplied.
            logger.add_scalars('agent%s/losses' % self.agent_name,
                               {'vf_loss': vf_loss,
                                'pol_loss': pol_loss},
                               self.niter)
        # BUGFIX: niter was never incremented, so every logged scalar landed
        # on global step 0.
        self.niter += 1

    def update_all_targets(self):
        """
        Update all target networks (called after normal updates have been
        performed for each agent)
        """
        soft_update(self.critic, self.target_critic, self.tau)
        soft_update(self.policy, self.target_policy, self.tau)

    def get_params(self):
        """Return a dict of all network and optimizer state dicts."""
        return {'policy': self.policy.state_dict(),
                'critic': self.critic.state_dict(),
                'target_policy': self.target_policy.state_dict(),
                'target_critic': self.target_critic.state_dict(),
                'policy_optimizer': self.policy_optimizer.state_dict(),
                'critic_optimizer': self.critic_optimizer.state_dict()}

    def load_params(self, params):
        """Restore all network and optimizer states from get_params() output."""
        self.policy.load_state_dict(params['policy'])
        self.critic.load_state_dict(params['critic'])
        self.target_policy.load_state_dict(params['target_policy'])
        self.target_critic.load_state_dict(params['target_critic'])
        self.policy_optimizer.load_state_dict(params['policy_optimizer'])
        self.critic_optimizer.load_state_dict(params['critic_optimizer'])
Exemple #10
0
def main(args: argparse.Namespace):
    """Domain-adversarial training for semantic segmentation.

    Trains a segmentation model on a labeled source dataset while adapting
    to an unlabeled target dataset via a domain discriminator, validating
    on the target's val split and checkpointing the best mean-IoU model.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(
        root=args.source_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size,
                                ratio=args.resize_ratio,
                                scale=(0.5, 1.)),
            T.ColorJitter(brightness=0.3, contrast=0.3),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    # Target training data is unlabeled from the model's perspective; only
    # geometric/flip augmentation is applied (no color jitter).
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(
        root=args.target_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size,
                                ratio=(2., 2.),
                                scale=(0.5, 1.)),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_target_dataset = target_dataset(
        root=args.target_root,
        split='val',
        transforms=T.Compose([
            T.Resize(image_size=args.test_input_size,
                     label_size=args.test_output_size),
            T.NormalizeAndTranspose(),
        ]),
    )
    val_target_loader = DataLoader(val_target_dataset,
                                   batch_size=1,
                                   shuffle=False,
                                   pin_memory=True)

    # Wrap loaders so training can draw a fixed number of iterations per
    # epoch regardless of dataset length.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    num_classes = train_source_dataset.num_classes
    model = models.__dict__[args.arch](num_classes=num_classes).to(device)
    discriminator = Discriminator(num_classes=num_classes).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(model.get_parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    optimizer_d = Adam(discriminator.parameters(),
                       lr=args.lr_d,
                       betas=(0.9, 0.99))
    # Polynomial LR decay over the full training schedule for both optimizers.
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. - float(x) / args.epochs / args.iters_per_epoch)**(args.lr_power))
    lr_scheduler_d = LambdaLR(
        optimizer_d, lambda x:
        (1. - float(x) / args.epochs / args.iters_per_epoch)**(args.lr_power))

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        discriminator.load_state_dict(checkpoint['discriminator'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        optimizer_d.load_state_dict(checkpoint['optimizer_d'])
        lr_scheduler_d.load_state_dict(checkpoint['lr_scheduler_d'])
        args.start_epoch = checkpoint['epoch'] + 1

    # define loss function (criterion)
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=args.ignore_label).to(device)
    dann = DomainAdversarialEntropyLoss(discriminator)
    # Upsample logits to label resolution for training and validation.
    interp_train = nn.Upsample(size=args.train_size[::-1],
                               mode='bilinear',
                               align_corners=True)
    interp_val = nn.Upsample(size=args.test_output_size[::-1],
                             mode='bilinear',
                             align_corners=True)

    # define visualization function
    decode = train_source_dataset.decode_target

    def visualize(image, pred, label, prefix):
        """
        Args:
            image (tensor): 3 x H x W
            pred (tensor): C x H x W
            label (tensor): H x W
            prefix: prefix of the saving image
        """
        image = image.detach().cpu().numpy()
        pred = pred.detach().max(dim=0)[1].cpu().numpy()
        label = label.cpu().numpy()
        for tensor, name in [
            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))),
             "image"), (decode(label), "label"), (decode(pred), "pred")
        ]:
            tensor.save(logger.get_image_path("{}_{}.png".format(prefix,
                                                                 name)))

    # Test-only mode: evaluate once and exit.
    if args.phase == 'test':
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           visualize, args)
        print(confmat)
        return

    # start training
    best_iou = 0.
    for epoch in range(args.start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler.get_lr(), lr_scheduler_d.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, model, interp_train,
              criterion, dann, optimizer, lr_scheduler, optimizer_d,
              lr_scheduler_d, epoch, visualize if args.debug else None, args)

        # evaluate on validation set
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           None, args)
        print(confmat.format(train_source_dataset.classes))
        acc_global, acc, iu = confmat.compute()

        # calculate the mean iou over partial classes
        indexes = [
            train_source_dataset.classes.index(name)
            for name in train_source_dataset.evaluate_classes
        ]
        iu = iu[indexes]
        mean_iou = iu.mean()

        # remember best acc@1 and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'discriminator': discriminator.state_dict(),
                'optimizer': optimizer.state_dict(),
                'optimizer_d': optimizer_d.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'lr_scheduler_d': lr_scheduler_d.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
        if mean_iou > best_iou:
            shutil.copy(logger.get_checkpoint_path(epoch),
                        logger.get_checkpoint_path('best'))
        best_iou = max(best_iou, mean_iou)
        print("Target: {} Best: {}".format(mean_iou, best_iou))

    logger.close()
Exemple #11
0
class Session:
    """Training session for Quickdraw sketch classification.

    Wires up a (frozen-or-trainable) ResNet backbone plus either a DELF or
    plain ResNet classification head, with tensorboard writers, dataloader
    construction, checkpointing, and per-batch inference helpers. All
    hyperparameters come from the module-level `settings`.
    """
    def __init__(self):
        self.log_dir = settings.log_dir
        self.model_dir = settings.model_dir
        ensure_dir(settings.log_dir)
        ensure_dir(settings.model_dir)
        logger.info('set log dir as %s' % settings.log_dir)
        logger.info('set model dir as %s' % settings.model_dir)

        self.backbone = ResBackbone().cuda()
        self.classification_model = Delf_classification(
            settings.num_classes
        ).cuda(
        ) if settings.classification_model == 'delf' else Resnet_classification(
            settings.num_classes).cuda()

        # For the DELF head the backbone is frozen: only the head trains.
        if settings.classification_model == 'delf':
            for para in self.backbone.parameters():
                para.requires_grad = False

        self.backbone = nn.DataParallel(self.backbone,
                                        device_ids=range(settings.num_gpu))
        self.classification_model = nn.DataParallel(self.classification_model,
                                                    device_ids=range(
                                                        settings.num_gpu))

        self.crit = nn.CrossEntropyLoss().cuda()

        self.epoch_count = 0
        self.step = 0
        self.save_steps = settings.save_steps
        self.num_workers = settings.num_workers
        self.batch_size = settings.batch_size
        self.writers = {}
        self.dataloaders = {}

        # Only the trainable parameters are handed to the optimizer:
        # head-only for DELF, backbone+head for plain ResNet.
        if settings.classification_model == 'delf':
            parameters = list(self.classification_model.parameters())
        elif settings.classification_model == 'res':
            parameters = list(self.backbone.parameters()) + list(
                self.classification_model.parameters())

        # NOTE(review): weight_decay=1 is unusually large for Adam — confirm
        # this is intentional and not a misplaced settings value.
        self.opt = Adam(parameters,
                        lr=settings.lr,
                        weight_decay=1,
                        amsgrad=True)
        self.sche = MultiStepLR(self.opt,
                                milestones=settings.iter_sche,
                                gamma=0.1)

    def tensorboard(self, name):
        """Create (and cache) a SummaryWriter for the named log stream."""
        self.writers[name] = SummaryWriter(
            os.path.join(self.log_dir, name + '.events'))
        return self.writers[name]

    def write(self, name, out):
        """Log a dict of scalars to tensorboard and the text logger."""
        for k, v in out.items():
            self.writers[name].add_scalar(name + '/' + k, v, self.step)

        out['lr'] = self.opt.param_groups[0]['lr']
        out['step'] = self.step
        # NOTE(review): 'eooch_count' is a typo for 'epoch_count'; it only
        # affects the log-line field name — confirm before renaming.
        out['eooch_count'] = self.epoch_count
        outputs = ["{}:{:.4g}".format(k, v) for k, v in out.items()]
        logger.info(name + '--' + ' '.join(outputs))

    def get_dataloader(self,
                       dataset_name,
                       keyid_cat_catindex,
                       words,
                       use_iter=True):
        """Build a shuffled DataLoader (or iterator) over a QuickdrawDataset."""
        dataset = QuickdrawDataset(dataset_name, keyid_cat_catindex, words)

        # The test split keeps partial batches; training drops them.
        dataloader_this = \
                    DataLoader(dataset,
                               batch_size=self.batch_size,
                               shuffle=True,
                               num_workers=self.num_workers,
                               drop_last=False if dataset_name == 'test' else True)
        if use_iter:
            return iter(dataloader_this)
        else:
            return dataloader_this

    def save_checkpoints(self, name):
        """Save model, optimizer, and progress counters under model_dir/name."""
        ckp_path = os.path.join(self.model_dir, name)
        obj = {
            'backbone': self.backbone.state_dict(),
            'classification_model': self.classification_model.state_dict(),
            'clock': self.step,
            'epoch_count': self.epoch_count,
            'opt': self.opt.state_dict(),
        }
        torch.save(obj, ckp_path)

    def load_checkpoints(self, name):
        """Restore a checkpoint by name; silently reinitialize if missing."""
        ckp_path = os.path.join(self.model_dir, name)
        try:
            obj = torch.load(ckp_path)
            print('load checkpoint: %s' % ckp_path)
        except FileNotFoundError:
            print('Find no checkpoint, reinitialize one!')
            return
        self.backbone.load_state_dict(obj['backbone'])

        self.classification_model.load_state_dict(obj['classification_model'])
        self.opt.load_state_dict(obj['opt'])
        self.step = obj['clock']
        self.epoch_count = obj['epoch_count']
        # Keep the LR schedule in sync with the restored step counter.
        self.sche.last_epoch = self.step

    def load_checkpoints_delf_init(self, name):
        """Load only the backbone weights (DELF warm-start initialization)."""
        ckp_path = os.path.join(self.model_dir, name)
        obj = torch.load(ckp_path)
        self.backbone.load_state_dict(obj['backbone'])

    def inf_batch(self, batch, return_correct=False):
        """Run one batch through backbone+head; return loss and accuracy.

        With return_correct=True also returns the raw correct count and
        batch size for aggregate accuracy computation.
        """
        drawing, label = batch['drawing'], batch['word']
        drawing, label = drawing.cuda(), label.cuda()

        x = self.backbone(drawing)

        if settings.classification_model == 'res':
            pred = self.classification_model(x)
        elif settings.classification_model == 'delf':
            pred, attention_prob = self.classification_model(x)

        loss = self.crit(pred, label)

        _, pred_word = torch.max(pred, 1)
        total = len(label)
        correct = (pred_word == label).sum()
        accuracy = 100 * correct / total
        total = pred.shape[0]

        if return_correct == False:
            return loss, accuracy
        else:
            return loss, accuracy, correct, total
Exemple #12
0
    else:
        text_writer = open(os.path.join(opt.outf, 'train.csv'), 'w')

    vgg_ext = model_big.VggExtractor()
    capnet = model_big.CapsuleNet(4, opt.gpu_id)
    capsule_loss = model_big.CapsuleLoss(opt.gpu_id)

    optimizer = Adam(capnet.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

    if opt.resume > 0:
        capnet.load_state_dict(
            torch.load(
                os.path.join(opt.outf, 'capsule_' + str(opt.resume) + '.pt')))
        capnet.train(mode=True)
        optimizer.load_state_dict(
            torch.load(
                os.path.join(opt.outf, 'optim_' + str(opt.resume) + '.pt')))

        if opt.gpu_id >= 0:
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda(opt.gpu_id)

    if opt.gpu_id >= 0:
        capnet.cuda(opt.gpu_id)
        vgg_ext.cuda(opt.gpu_id)
        capsule_loss.cuda(opt.gpu_id)

    transform_fwd = transforms.Compose([
        transforms.Resize((opt.imageSize, opt.imageSize)),
Exemple #13
0
def main(args):
    """Train a UNet segmentation model driven by two TOML-style configs.

    Parameters
    ----------
    args : argparse.Namespace
        Expects ``model`` and ``dataset`` (config paths), optional
        ``checkpoint`` (weights to load), ``resume`` (continue optimizer /
        epoch counter from the checkpoint) and ``workers`` (loader workers).
    """
    model = load_config(args.model)
    dataset = load_config(args.dataset)

    device = torch.device("cuda" if model["common"]["cuda"] else "cpu")

    if model["common"]["cuda"] and not torch.cuda.is_available():
        sys.exit("Error: CUDA requested but not available")

    os.makedirs(model["common"]["checkpoint"], exist_ok=True)

    num_classes = len(dataset["common"]["classes"])
    net = UNet(num_classes)
    net = DataParallel(net)
    net = net.to(device)

    if model["common"]["cuda"]:
        torch.backends.cudnn.benchmark = True

    # Class weights are optional: only the weighted losses below need them.
    # NOTE: `weight` deliberately stays unbound for Lovasz without dataset
    # weights — the `"weight" in locals()` check further down relies on that,
    # so do not restructure this block.
    try:
        weight = torch.Tensor(dataset["weights"]["values"])
    except KeyError:
        if model["opt"]["loss"] in ("CrossEntropy", "mIoU", "Focal"):
            sys.exit(
                "Error: The loss function used, need dataset weights values")

    optimizer = Adam(net.parameters(), lr=model["opt"]["lr"])

    # Epoch to resume from (0 means training starts from scratch).
    resume = 0
    if args.checkpoint:

        def map_location(storage, _):
            return storage.cuda() if model["common"]["cuda"] else storage.cpu()

        # https://github.com/pytorch/pytorch/issues/7178
        chkpt = torch.load(args.checkpoint, map_location=map_location)
        net.load_state_dict(chkpt["state_dict"])

        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]

    # Select the criterion; the weighted losses consume `weight` from above.
    if model["opt"]["loss"] == "CrossEntropy":
        criterion = CrossEntropyLoss2d(weight=weight).to(device)
    elif model["opt"]["loss"] == "mIoU":
        criterion = mIoULoss2d(weight=weight).to(device)
    elif model["opt"]["loss"] == "Focal":
        criterion = FocalLoss2d(weight=weight).to(device)
    elif model["opt"]["loss"] == "Lovasz":
        criterion = LovaszLoss2d().to(device)
    else:
        sys.exit("Error: Unknown [opt][loss] value !")

    train_loader, val_loader = get_dataset_loaders(model, dataset,
                                                   args.workers)

    num_epochs = model["opt"]["epochs"]
    if resume >= num_epochs:
        sys.exit(
            "Error: Epoch {} set in {} already reached by the checkpoint provided"
            .format(num_epochs, args.model))

    history = collections.defaultdict(list)
    log = Log(os.path.join(model["common"]["checkpoint"], "log"))

    log.log("--- Hyper Parameters on Dataset: {} ---".format(
        dataset["common"]["dataset"]))
    log.log("Batch Size:\t {}".format(model["common"]["batch_size"]))
    log.log("Image Size:\t {}".format(model["common"]["image_size"]))
    log.log("Learning Rate:\t {}".format(model["opt"]["lr"]))
    log.log("Loss function:\t {}".format(model["opt"]["loss"]))
    if "weight" in locals():
        log.log("Weights :\t {}".format(dataset["weights"]["values"]))
    log.log("---")

    for epoch in range(resume, num_epochs):
        log.log("Epoch: {}/{}".format(epoch + 1, num_epochs))

        train_hist = train(train_loader, num_classes, device, net, optimizer,
                           criterion)
        log.log(
            "Train    loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".
            format(
                train_hist["loss"],
                train_hist["miou"],
                dataset["common"]["classes"][1],
                train_hist["fg_iou"],
                train_hist["mcc"],
            ))

        for k, v in train_hist.items():
            history["train " + k].append(v)

        val_hist = validate(val_loader, num_classes, device, net, criterion)
        log.log(
            "Validate loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".
            format(val_hist["loss"], val_hist["miou"],
                   dataset["common"]["classes"][1], val_hist["fg_iou"],
                   val_hist["mcc"]))

        for k, v in val_hist.items():
            history["val " + k].append(v)

        # Refresh the training-curve plot after every epoch.
        visual = "history-{:05d}-of-{:05d}.png".format(epoch + 1, num_epochs)
        plot(os.path.join(model["common"]["checkpoint"], visual), history)

        checkpoint = "checkpoint-{:05d}-of-{:05d}.pth".format(
            epoch + 1, num_epochs)

        states = {
            "epoch": epoch + 1,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict()
        }

        torch.save(states,
                   os.path.join(model["common"]["checkpoint"], checkpoint))
Exemple #14
0
class Agent():
    """PPO agent using the Truly-PPO objective plus an auxiliary
    value-distillation phase (Phasic-Policy-Gradient style).

    Keeps current/old copies of both the policy and value networks; the
    old copies provide the fixed reference outputs for the clipped /
    KL-regularized losses and are synchronized after each update phase.
    """
    def __init__(self, state_dim, action_dim, is_training_mode,
                 policy_kl_range, policy_params, value_clip, entropy_coef,
                 vf_loss_coef, batchsize, PPO_epochs, gamma, lam,
                 learning_rate):
        self.policy_kl_range = policy_kl_range
        self.policy_params = policy_params
        self.value_clip = value_clip
        self.entropy_coef = entropy_coef
        self.vf_loss_coef = vf_loss_coef
        self.batchsize = batchsize
        self.PPO_epochs = PPO_epochs
        self.is_training_mode = is_training_mode
        self.action_dim = action_dim
        # Fixed exploration stddev for the Gaussian action distribution.
        self.std = torch.ones([1, action_dim]).float().to(device)

        self.policy = Policy_Model(state_dim, action_dim)
        self.policy_old = Policy_Model(state_dim, action_dim)
        self.policy_optimizer = Adam(self.policy.parameters(),
                                     lr=learning_rate)

        self.value = Value_Model(state_dim, action_dim)
        self.value_old = Value_Model(state_dim, action_dim)
        self.value_optimizer = Adam(self.value.parameters(), lr=learning_rate)

        self.policy_memory = PolicyMemory()
        self.policy_loss = TrulyPPO(policy_kl_range, policy_params, value_clip,
                                    vf_loss_coef, entropy_coef, gamma, lam)

        self.aux_memory = AuxMemory()
        self.aux_loss = JointAux()

        self.distributions = Continous()

        if is_training_mode:
            self.policy.train()
            self.value.train()
        else:
            self.policy.eval()
            self.value.eval()

    def save_eps(self, state, action, reward, done, next_state):
        """Append one transition to the policy rollout memory."""
        self.policy_memory.save_eps(state, action, reward, done, next_state)

    def act(self, state):
        """Return an action (numpy array) for a single environment state."""
        state = torch.FloatTensor(state).unsqueeze(0).to(device).detach()
        action_mean, _ = self.policy(state)

        # We don't need sample the action in Test Mode
        # only sampling the action in Training Mode in order to exploring the actions
        if self.is_training_mode:
            # Sample the action
            action = self.distributions.sample(action_mean, self.std)
        else:
            action = action_mean

        return action.squeeze(0).cpu().numpy()

    # Get loss and Do backpropagation
    def training_ppo(self, states, actions, rewards, dones, next_states):
        """One joint policy/value gradient step on a PPO minibatch."""
        action_mean, _ = self.policy(states)
        values = self.value(states)
        old_action_mean, _ = self.policy_old(states)
        old_values = self.value_old(states)
        next_values = self.value(next_states)

        loss = self.policy_loss.compute_loss(action_mean, self.std,
                                             old_action_mean, self.std, values,
                                             old_values, next_values, actions,
                                             rewards, dones)

        self.policy_optimizer.zero_grad()
        self.value_optimizer.zero_grad()

        loss.backward()

        self.policy_optimizer.step()
        self.value_optimizer.step()

    def training_aux(self, states):
        """One auxiliary-phase step distilling value targets into the
        policy network's value head."""
        Returns = self.value(states).detach()

        action_mean, values = self.policy(states)
        old_action_mean, _ = self.policy_old(states)

        joint_loss = self.aux_loss.compute_loss(action_mean, self.std,
                                                old_action_mean, self.std,
                                                values, Returns)

        self.policy_optimizer.zero_grad()
        joint_loss.backward()
        self.policy_optimizer.step()

    # Update the model
    def update_ppo(self):
        """Run the PPO phase over the stored rollout, then hand the states
        to the auxiliary memory and sync the old networks."""
        dataloader = DataLoader(self.policy_memory,
                                self.batchsize,
                                shuffle=False)

        # Optimize policy for K epochs:
        for _ in range(self.PPO_epochs):
            for states, actions, rewards, dones, next_states in dataloader:
                self.training_ppo(states.float().to(device),
                                  actions.float().to(device),
                                  rewards.float().to(device),
                                  dones.float().to(device),
                                  next_states.float().to(device))

        # Clear the memory
        states, _, _, _, _ = self.policy_memory.get_all()
        self.aux_memory.save_all(states)
        self.policy_memory.clear_memory()

        # Copy new weights into old policy:
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.value_old.load_state_dict(self.value.state_dict())

    def update_aux(self):
        """Run the auxiliary phase over the stored states, then sync the
        old policy."""
        dataloader = DataLoader(self.aux_memory, self.batchsize, shuffle=False)

        # Optimize policy for K epochs:
        for _ in range(self.PPO_epochs):
            for states in dataloader:
                self.training_aux(states.float().to(device))

        # Clear the memory
        self.aux_memory.clear_memory()

        # Copy new weights into old policy:
        self.policy_old.load_state_dict(self.policy.state_dict())

    def save_weights(self, directory='SlimeVolley'):
        """Save policy/value networks plus optimizer states.

        Parameters
        ----------
        directory : str, optional
            Target folder; defaults to the previously hard-coded location
            so existing callers are unaffected.
        """
        torch.save(
            {
                'model_state_dict': self.policy.state_dict(),
                'optimizer_state_dict': self.policy_optimizer.state_dict()
            }, os.path.join(directory, 'policy.tar'))

        torch.save(
            {
                'model_state_dict': self.value.state_dict(),
                'optimizer_state_dict': self.value_optimizer.state_dict()
            }, os.path.join(directory, 'value.tar'))

    def load_weights(self, directory='SlimeVolley'):
        """Restore networks and optimizer states saved by save_weights.

        Parameters
        ----------
        directory : str, optional
            Folder to load from; defaults to the previous hard-coded path.
        """
        policy_checkpoint = torch.load(os.path.join(directory, 'policy.tar'))
        self.policy.load_state_dict(policy_checkpoint['model_state_dict'])
        self.policy_optimizer.load_state_dict(
            policy_checkpoint['optimizer_state_dict'])

        value_checkpoint = torch.load(os.path.join(directory, 'value.tar'))
        self.value.load_state_dict(value_checkpoint['model_state_dict'])
        self.value_optimizer.load_state_dict(
            value_checkpoint['optimizer_state_dict'])
Exemple #15
0
def train_multiple_epochs(train_dataset,
                          test_dataset,
                          model,
                          epochs,
                          batch_size,
                          lr,
                          lr_decay_factor,
                          lr_decay_step_size,
                          weight_decay,
                          ARR=0,
                          logger=None,
                          continue_from=None,
                          res_dir=None):
    """Train `model` for `epochs` epochs, evaluating test RMSE each epoch.

    Optionally resumes model/optimizer state from checkpoint files in
    `res_dir` (when `continue_from` is an epoch number) and decays the
    learning rate by `lr_decay_factor` every `lr_decay_step_size` epochs.
    Returns the final test RMSE.
    """
    workers = mp.cpu_count()
    loader_train = DataLoader(train_dataset,
                              batch_size,
                              shuffle=True,
                              num_workers=workers)
    loader_test = DataLoader(test_dataset,
                             batch_size,
                             shuffle=False,
                             num_workers=workers)

    model.to(device).reset_parameters()
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    first_epoch = 1
    if continue_from is not None:
        # Resume both model and optimizer from the saved checkpoints.
        model_ckpt = os.path.join(
            res_dir, 'model_checkpoint{}.pth'.format(continue_from))
        optim_ckpt = os.path.join(
            res_dir, 'optimizer_checkpoint{}.pth'.format(continue_from))
        model.load_state_dict(torch.load(model_ckpt))
        optimizer.load_state_dict(torch.load(optim_ckpt))
        first_epoch = continue_from + 1

    # Synchronize before timing so GPU work is fully accounted for.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    started = time.perf_counter()
    test_rmses = []
    progress = tqdm(range(first_epoch, epochs + first_epoch))
    for epoch in progress:
        epoch_train_loss = train(model,
                                 optimizer,
                                 loader_train,
                                 device,
                                 regression=True,
                                 ARR=ARR)
        test_rmses.append(eval_rmse(model, loader_test, device))
        eval_info = {
            'epoch': epoch,
            'train_loss': epoch_train_loss,
            'test_rmse': test_rmses[-1],
        }
        progress.set_description(
            'Epoch {}, train loss {:.6f}, test rmse {:.6f}'.format(
                *eval_info.values()))

        # Step-decay the learning rate in place.
        if epoch % lr_decay_step_size == 0:
            for group in optimizer.param_groups:
                group['lr'] = lr_decay_factor * group['lr']

        if logger is not None:
            logger(eval_info, model, optimizer)

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    duration = time.perf_counter() - started

    print('Final Test RMSE: {:.6f}, Duration: {:.6f}'.format(
        test_rmses[-1], duration))

    return test_rmses[-1]
    def fit(self, model, feature_extraction, protocol, log_dir, subset='train',
            epochs=1000, restart=0, gpu=False):
        """Train model

        Parameters
        ----------
        model : nn.Module
            Embedding model
        feature_extraction :
            Feature extraction.
        protocol : pyannote.database.Protocol
        log_dir : str
            Directory where models and other log files are stored.
        subset : {'train', 'development', 'test'}, optional
            Defaults to 'train'.
        epochs : int, optional
            Train model for that many epochs.
        restart : int, optional
            Restart training at this epoch. Defaults to train from scratch.
        gpu : bool, optional
        """

        import tensorboardX
        writer = tensorboardX.SummaryWriter(log_dir=log_dir)

        checkpoint = Checkpoint(log_dir=log_dir,
                                      restart=restart > 0)

        batch_generator = self.get_batch_generator(feature_extraction)
        batches = batch_generator(protocol, subset=subset)
        # Prime the generator so `labels` / `n_classes` below are populated.
        batch = next(batches)

        batches_per_epoch = batch_generator.batches_per_epoch

        # save list of classes (one speaker per line)
        labels = batch_generator.labels
        classes_txt = self.CLASSES_TXT.format(log_dir=log_dir)
        with open(classes_txt, mode='w') as fp:
            for label in labels:
                fp.write(f'{label}\n')

        # initialize classifier
        n_classes = batch_generator.n_classes
        classifier = Classifier(model.output_dim, n_classes,
                                linear=self.linear)

        # load precomputed weights in case of restart
        if restart > 0:
            weights_pt = checkpoint.WEIGHTS_PT.format(
                log_dir=log_dir, epoch=restart)
            model.load_state_dict(torch.load(weights_pt))
            classifier_pt = self.CLASSIFIER_PT.format(
                log_dir=log_dir, epoch=restart)
            # BUG FIX: `classifier_pt` was computed but never loaded, so a
            # restart silently reinitialized the classifier head even though
            # its weights are saved via `checkpoint.on_epoch_end` below.
            classifier.load_state_dict(torch.load(classifier_pt))

        # send models to GPU
        if gpu:
            model = model.cuda()
            classifier = classifier.cuda(device=None)

        model.internal = False

        optimizer = Adam(list(model.parameters()) + \
                         list(classifier.parameters()))
        if restart > 0:
            optimizer_pt = checkpoint.OPTIMIZER_PT.format(
                log_dir=log_dir, epoch=restart)
            optimizer.load_state_dict(torch.load(optimizer_pt))
            if gpu:
                # Optimizer state tensors were saved on CPU; move them back.
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.cuda()

        epoch = restart if restart > 0 else -1
        while True:
            epoch += 1
            if epoch > epochs:
                break

            loss_avg = 0.

            desc = 'Epoch #{0}'.format(epoch)
            for i in tqdm(range(batches_per_epoch), desc=desc):

                model.zero_grad()

                batch = next(batches)

                X = batch['X']
                y = batch['y']
                # Some models expect (sequence, batch, features) ordering.
                if not getattr(model, 'batch_first', True):
                    X = np.rollaxis(X, 0, 2)
                X = np.array(X, dtype=np.float32)
                X = Variable(torch.from_numpy(X))
                y = Variable(torch.from_numpy(y))

                if gpu:
                    X = X.cuda()
                    y = y.cuda()

                fX = model(X)
                y_pred = classifier(fX)

                loss = self.loss_(y_pred, y)

                # log loss
                if gpu:
                    loss_ = float(loss.data.cpu().numpy())
                else:
                    loss_ = float(loss.data.numpy())
                loss_avg += loss_

                loss.backward()
                optimizer.step()

            loss_avg /= batches_per_epoch
            writer.add_scalar('train/softmax/loss', loss_avg,
                              global_step=epoch)

            checkpoint.on_epoch_end(epoch, model, optimizer,
                                    extra={self.CLASSIFIER_PT: classifier})
Exemple #17
0
def train(args, model, enc=False):
    """Train an encoder (enc=True) or full encoder-decoder model on
    cityscapes, validating after every epoch and checkpointing the best
    model by val IoU (or, when IoU is disabled, by val loss).

    Parameters
    ----------
    args : argparse.Namespace
        Expects datadir, savedir, height, num_workers, batch_size, cuda,
        weighted, resume, num_epochs, iouTrain, iouVal, ignoreindex,
        visualize, steps_plot, steps_loss, epochs_save, port.
    model : nn.Module
        Model whose forward accepts ``only_encode=enc``.
    enc : bool, optional
        Train only the encoder branch when True.

    Returns
    -------
    nn.Module
        The trained model (convenience for encoder-decoder training).
    """
    best_acc = 0

    #TODO: calculate weights by processing dataset histogram (now its being set by hand from the torch values)
    #create a loder to run all images and calculate histogram of labels, then create weight array using class balancing

    weight = torch.ones(NUM_CLASSES)
    if (enc):        
        weight[0] = 4.38133159
        weight[1] = 1.29574148
    else:
        weight[0] = 4.40513628
        weight[1] = 1.293674
        
    # Upsample predictions back to label resolution before IoU accounting:
    # the encoder outputs at 1/16 scale, the full model at 1/2.
    if (enc):
        up = torch.nn.Upsample(scale_factor=16, mode='bilinear')
    else:
        up = torch.nn.Upsample(scale_factor=2, mode='bilinear')
        
    if args.cuda:
        up = up.cuda()

    assert os.path.exists(args.datadir), "Error: datadir (dataset directory) could not be loaded"

    co_transform = MyCoTransform(enc, augment=True, height=args.height)#1024)
    co_transform_val = MyCoTransform(enc, augment=False, height=args.height)#1024)
    dataset_train = cityscapes(args.datadir, co_transform, 'train')
    dataset_val = cityscapes(args.datadir, co_transform_val, 'val')

    loader = DataLoader(dataset_train, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=True)
    loader_val = DataLoader(dataset_val, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)

    if args.cuda:
        weight = weight.cuda()
  
    if args.weighted:
        criterion = CrossEntropyLoss2d(weight)
    else:            
        criterion = CrossEntropyLoss2d()
        
    print(type(criterion))

    savedir = args.savedir

    if (enc):
        automated_log_path = savedir + "/automated_log_encoder.txt"
        modeltxtpath = savedir + "/model_encoder.txt"
    else:
        automated_log_path = savedir + "/automated_log.txt"
        modeltxtpath = savedir + "/model.txt"    

    if (not os.path.exists(automated_log_path)):    #dont add first line if it exists 
        with open(automated_log_path, "a") as myfile:
            myfile.write("Epoch\t\tTrain-loss\t\tTest-loss\t\tTrain-IoU\t\tTest-IoU\t\tlearningRate")

    with open(modeltxtpath, "w") as myfile:
        myfile.write(str(model))


    #TODO: reduce memory in first gpu: https://discuss.pytorch.org/t/multi-gpu-training-memory-usage-in-balance/4163/4        #https://github.com/pytorch/pytorch/issues/1893

    #optimizer = Adam(model.parameters(), 5e-4, (0.9, 0.999),  eps=1e-08, weight_decay=2e-4)     ## scheduler 1
    optimizer = Adam(model.parameters(), 5e-4, (0.9, 0.999),  eps=1e-08, weight_decay=1e-4)      ## scheduler 2

    start_epoch = 1
    if args.resume:
        #Must load weights, optimizer, epoch and best value. 
        if enc:
            filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
        else:
            filenameCheckpoint = savedir + '/checkpoint.pth.tar'

        assert os.path.exists(filenameCheckpoint), "Error: resume option was used but checkpoint was not found in folder"
        checkpoint = torch.load(filenameCheckpoint)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        best_acc = checkpoint['best_acc']
        print("=> Loaded checkpoint at epoch {})".format(checkpoint['epoch']))

    #scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5) # set up scheduler     ## scheduler 1
    # Polynomial LR decay (power 0.9) over the configured number of epochs.
    lambda1 = lambda epoch: pow((1-((epoch-1)/args.num_epochs)),0.9)  ## scheduler 2
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)                             ## scheduler 2

    if args.visualize and args.steps_plot > 0:
        board = Dashboard(args.port)

    for epoch in range(start_epoch, args.num_epochs+1):
        print("----- TRAINING - EPOCH", epoch, "-----")

        scheduler.step(epoch)    ## scheduler 2

        epoch_loss = []
        time_train = []
     
        doIouTrain = args.iouTrain   
        doIouVal =  args.iouVal      

        if (doIouTrain):
            iouEvalTrain = iouEval(NUM_CLASSES, args.ignoreindex)

        usedLr = 0
        for param_group in optimizer.param_groups:
            print("LEARNING RATE: ", param_group['lr'])
            usedLr = float(param_group['lr'])

        model.train()
        for step, (images, labels, images_orig, labels_orig) in enumerate(loader):

            start_time = time.time()
            #print (labels.size())
            #print (np.unique(labels.numpy()))
            #print("labels: ", np.unique(labels[0].numpy()))
            #labels = torch.ones(4, 1, 512, 1024).long()
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()

            inputs = Variable(images)
            targets = Variable(labels)
            outputs = model(inputs, only_encode=enc)

            #print("targets", np.unique(targets[:, 0].cpu().data.numpy()))

            optimizer.zero_grad()
            loss = criterion(outputs, targets[:, 0])
            loss.backward()
            optimizer.step()

            # NOTE(review): `loss.data[0]` only works on PyTorch < 0.4;
            # newer versions require `loss.item()`.
            epoch_loss.append(loss.data[0])
            time_train.append(time.time() - start_time)

            if (doIouTrain):
                #start_time_iou = time.time()
                upsampledOutputs = up(outputs)
                iouEvalTrain.addBatch(upsampledOutputs.max(1)[1].unsqueeze(1).data, labels_orig)
                #print ("Time to add confusion matrix: ", time.time() - start_time_iou)      

            #print(outputs.size())
            if args.visualize and args.steps_plot > 0 and step % args.steps_plot == 0:
                start_time_plot = time.time()
                image = inputs[0].cpu().data
                #image[0] = image[0] * .229 + .485
                #image[1] = image[1] * .224 + .456
                #image[2] = image[2] * .225 + .406
                #print("output", np.unique(outputs[0].cpu().max(0)[1].data.numpy()))
                board.image(image, f'input (epoch: {epoch}, step: {step})')
                if isinstance(outputs, list):   #merge gpu tensors
                    board.image(color_transform(outputs[0][0].cpu().max(0)[1].data.unsqueeze(0)),
                    f'output (epoch: {epoch}, step: {step})')
                else:
                    board.image(color_transform(outputs[0].cpu().max(0)[1].data.unsqueeze(0)),
                    f'output (epoch: {epoch}, step: {step})')
                board.image(color_transform(targets[0].cpu().data),
                    f'target (epoch: {epoch}, step: {step})')
                print ("Time to paint images: ", time.time() - start_time_plot)
            if args.steps_loss > 0 and step % args.steps_loss == 0:
                average = sum(epoch_loss) / len(epoch_loss)
                print(f'loss: {average:0.4} (epoch: {epoch}, step: {step})', 
                        "// Avg time/img: %.4f s" % (sum(time_train) / len(time_train) / args.batch_size))

            
        average_epoch_loss_train = sum(epoch_loss) / len(epoch_loss)
        
        iouTrain = 0
        if (doIouTrain):
            iouTrain, iou_classes = iouEvalTrain.getIoU()
            iouStr = getColorEntry(iouTrain)+'{:0.2f}'.format(iouTrain*100) + '\033[0m'
            print ("EPOCH IoU on TRAIN set: ", iouStr, "%", iou_classes)  

        #Validate on 500 val images after each epoch of training
        print("----- VALIDATING - EPOCH", epoch, "-----")
        model.eval()
        epoch_loss_val = []
        time_val = []

        if (doIouVal):
            iouEvalVal = iouEval(NUM_CLASSES, args.ignoreindex)

        for step, (images, labels, images_orig, labels_orig) in enumerate(loader_val):
            start_time = time.time()
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()

            inputs = Variable(images, volatile=True)    #volatile flag makes it free backward or outputs for eval
            targets = Variable(labels, volatile=True)
            outputs = model(inputs, only_encode=enc) 

            loss = criterion(outputs, targets[:, 0])
            epoch_loss_val.append(loss.data[0])
            time_val.append(time.time() - start_time)


            #Add batch to calculate TP, FP and FN for iou estimation
            if (doIouVal):
                #start_time_iou = time.time()
                upsampledOutputs = up(outputs)
                iouEvalVal.addBatch(upsampledOutputs.max(1)[1].unsqueeze(1).data, labels_orig)
                #print ("Time to add confusion matrix: ", time.time() - start_time_iou)

            if args.visualize and args.steps_plot > 0 and step % args.steps_plot == 0:
                start_time_plot = time.time()
                image = inputs[0].cpu().data
                board.image(image, f'VAL input (epoch: {epoch}, step: {step})')
                if isinstance(outputs, list):   #merge gpu tensors
                    board.image(color_transform(outputs[0][0].cpu().max(0)[1].data.unsqueeze(0)),
                    f'VAL output (epoch: {epoch}, step: {step})')
                else:
                    board.image(color_transform(outputs[0].cpu().max(0)[1].data.unsqueeze(0)),
                    f'VAL output (epoch: {epoch}, step: {step})')
                board.image(color_transform(targets[0].cpu().data),
                    f'VAL target (epoch: {epoch}, step: {step})')
                print ("Time to paint images: ", time.time() - start_time_plot)
            if args.steps_loss > 0 and step % args.steps_loss == 0:
                average = sum(epoch_loss_val) / len(epoch_loss_val)
                print(f'VAL loss: {average:0.4} (epoch: {epoch}, step: {step})', 
                        "// Avg time/img: %.4f s" % (sum(time_val) / len(time_val) / args.batch_size))
                       

        average_epoch_loss_val = sum(epoch_loss_val) / len(epoch_loss_val)
        #scheduler.step(average_epoch_loss_val, epoch)  ## scheduler 1   # update lr if needed

        iouVal = 0
        if (doIouVal):
            iouVal, iou_classes = iouEvalVal.getIoU()
            iouStr = getColorEntry(iouVal)+'{:0.2f}'.format(iouVal*100) + '\033[0m'
            print ("EPOCH IoU on VAL set: ", iouStr, "%", iou_classes) 
           

        # remember best valIoU and save checkpoint
        if iouVal == 0:
            current_acc = -average_epoch_loss_val
        else:
            current_acc = iouVal 
        is_best = current_acc > best_acc
        best_acc = max(current_acc, best_acc)
        if enc:
            filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
            filenameBest = savedir + '/model_best_enc.pth.tar'    
        else:
            filenameCheckpoint = savedir + '/checkpoint.pth.tar'
            filenameBest = savedir + '/model_best.pth.tar'
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': str(model),
            'state_dict': model.state_dict(),
            'best_acc': best_acc,
            'optimizer' : optimizer.state_dict(),
        }, is_best, filenameCheckpoint, filenameBest)

        #SAVE MODEL AFTER EPOCH
        if (enc):
            filename = f'{savedir}/model_encoder-{epoch:03}.pth'
            filenamebest = f'{savedir}/model_encoder_best.pth'
        else:
            filename = f'{savedir}/model-{epoch:03}.pth'
            filenamebest = f'{savedir}/model_best.pth'
        # NOTE(review): `step` here is the last batch index left over from the
        # validation loop — this looks like it should test `epoch` against
        # args.epochs_save instead; confirm before relying on periodic saves.
        if args.epochs_save > 0 and step > 0 and step % args.epochs_save == 0:
            torch.save(model.state_dict(), filename)
            print(f'save: (unknown) (epoch: {epoch})')
        if (is_best):
            torch.save(model.state_dict(), filenamebest)
            print(f'save: {filenamebest} (epoch: {epoch})')
            if (not enc):
                with open(savedir + "/best.txt", "w") as myfile:
                    myfile.write("Best epoch is %d, with Val-IoU= %.4f" % (epoch, iouVal))   
            else:
                with open(savedir + "/best_encoder.txt", "w") as myfile:
                    myfile.write("Best epoch is %d, with Val-IoU= %.4f" % (epoch, iouVal))           

        #SAVE TO FILE A ROW WITH THE EPOCH RESULT (train loss, val loss, train IoU, val IoU)
        #Epoch		Train-loss		Test-loss	Train-IoU	Test-IoU		learningRate
        with open(automated_log_path, "a") as myfile:
            myfile.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.8f" % (epoch, average_epoch_loss_train, average_epoch_loss_val, iouTrain, iouVal, usedLr ))
    
    return(model)   #return model (convenience for encoder-decoder training)
Exemple #18
0
class DDPGAgent(object):
    """A single DDPG agent.

    Bundles a policy network, a critic network, frozen target copies of
    both, their Adam optimizers, and an exploration mechanism (OU noise
    for continuous actions, epsilon / Boltzmann for discrete ones).
    """
    def __init__(self,
                 num_in_pol,
                 num_out_pol,
                 num_in_critic,
                 hidden_dim,
                 lr,
                 lr_critic_coef,
                 use_discrete_action,
                 weight_decay,
                 discrete_exploration_scheme,
                 boltzmann_temperature,
                 logger=None):
        """
        Inputs:
            num_in_pol (int): number of dimensions for policy input
            num_out_pol (int): number of dimensions for policy output
            num_in_critic (int): number of dimensions for critic input
        """
        # Keyword arguments shared by every network we instantiate.
        shared_kwargs = dict(hidden_dim=hidden_dim,
                             use_discrete_action=use_discrete_action,
                             logger=logger)

        self.policy = MLPNetwork(num_in_pol, num_out_pol,
                                 out_fn='tanh', name="policy",
                                 **shared_kwargs)
        self.critic = MLPNetwork(num_in_critic, 1,
                                 out_fn='linear', name="critic",
                                 **shared_kwargs)

        # Target networks are built without tracking gradients, then
        # hard-synced to the live networks.
        with torch.no_grad():
            self.target_policy = MLPNetwork(num_in_pol, num_out_pol,
                                            out_fn='tanh',
                                            name="target_policy",
                                            **shared_kwargs)
            self.target_critic = MLPNetwork(num_in_critic, 1,
                                            out_fn='linear',
                                            name="target_critic",
                                            **shared_kwargs)

        hard_update(self.target_policy, self.policy)
        hard_update(self.target_critic, self.critic)

        # Optimizers: the critic learning rate is scaled relative to the
        # policy learning rate.
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)
        self.critic_optimizer = Adam(self.critic.parameters(),
                                     lr=lr * lr_critic_coef,
                                     weight_decay=weight_decay)

        # Continuous agents carry an OU-noise process; discrete agents
        # later store their epsilon in self.exploration instead.
        self.exploration = None if use_discrete_action \
            else OUNoise(num_out_pol)
        self.use_discrete_action = use_discrete_action
        self.discrete_exploration_scheme = discrete_exploration_scheme
        self.boltzmann_temperature = boltzmann_temperature

    def reset_noise(self):
        """Reset the OU-noise process (no-op for discrete agents)."""
        if self.use_discrete_action:
            return
        self.exploration.reset()

    def scale_noise(self, scale):
        """Set exploration magnitude: epsilon for discrete agents,
        the OU-noise scale for continuous ones."""
        if self.use_discrete_action:
            self.exploration = scale
        else:
            self.exploration.scale = scale

    def select_action(self, obs, is_exploring=False):
        """
        Take a step forward in environment for a minibatch of observations
        Inputs:
            obs (PyTorch Variable): Observations for this agent
            is_exploring (boolean): Whether or not to add exploration noise
        Outputs:
            action (PyTorch Variable): Actions for this agent
        """
        action = self.policy(obs)

        if not self.use_discrete_action:
            # Continuous control: optionally perturb with OU noise,
            # then clip into the valid action range.
            if is_exploring:
                action += Variable(Tensor(self.exploration.noise()),
                                   requires_grad=False)
            return action.clamp(-1., 1.)

        # Discrete control: greedy one-hot unless exploring.
        if not is_exploring:
            return onehot_from_logits(action, eps=0.)
        if self.discrete_exploration_scheme == 'e-greedy':
            return onehot_from_logits(action, eps=self.exploration)
        if self.discrete_exploration_scheme == 'boltzmann':
            return gumbel_softmax(action / self.boltzmann_temperature,
                                  hard=True)
        raise NotImplementedError

    def _components(self):
        """Name -> object mapping of everything that has a state dict."""
        return {
            'policy': self.policy,
            'critic': self.critic,
            'target_policy': self.target_policy,
            'target_critic': self.target_critic,
            'policy_optimizer': self.policy_optimizer,
            'critic_optimizer': self.critic_optimizer,
        }

    def get_params(self):
        """Snapshot all network and optimizer state dicts."""
        return {name: comp.state_dict()
                for name, comp in self._components().items()}

    def load_params(self, params):
        """Restore networks and optimizers from a get_params() snapshot."""
        for name, comp in self._components().items():
            comp.load_state_dict(params[name])
Exemple #19
0
def train(args):
    """Train the ZSIM sketch-image hashing model with early stopping.

    Builds the dataloader, model, optimizer and L1/L2 regularizers, then
    loops over the training set until ``args.patience`` consecutive
    evaluations pass without improving retrieval precision.  A checkpoint
    is written every ``args.save_every`` optimizer steps and the best
    model (by the last-computed precision, i.e. Precision@200) is saved
    as ``Best.pkl``.

    Args:
        args: argparse-style namespace providing paths, hyper-parameters
            and schedule settings (batch_size, lr, gpu_id, print_every,
            save_every, cum_num, warmup_steps, patience, ...).

    Side effects: writes TensorBoard summaries, log lines, and model
    checkpoints under ``args.save_dir``.
    """
    writer = SummaryWriter()
    logger = make_logger(args.log_file)

    # Select the packed data file for the zero-shot / non-zero-shot split.
    if args.zs:
        packed = args.packed_pkl_zs
    else:
        packed = args.packed_pkl_nozs

    data = ZSIH_dataloader(args.sketch_dir, args.image_dir, args.stats_file, args.embedding_file, packed, zs=args.zs)
    print(len(data))
    dataloader_train = DataLoader(dataset=data, num_workers=args.num_worker, \
                                  batch_size=args.batch_size,
                                  shuffle=args.shuffle)

    logger.info('Building the model ...')
    model = ZSIM(args.hidden_size, args.hashing_bit, args.semantics_size, data.pretrain_embedding.float(), 
                 adj_scaler=args.adj_scaler, dropout=args.dropout, fix_cnn=args.fix_cnn, 
                 fix_embedding=args.fix_embedding, logger=logger)
    logger.info('Building the optimizer ...')
    optimizer = Adam(params=model.parameters(), lr=args.lr)
    #optimizer = SGD(params=model.parameters(), lr=args.lr, momentum=0.9)
    # L1 (weight 1) and L2 (weight 0.005) penalty terms; the L2 weight is
    # divided back out below when logging so the raw norm is reported.
    l1_regularization = _Regularization(model, 1, p=1, logger=logger)
    l2_regularization = _Regularization(model, 0.005, p=2, logger=logger)

    # Optionally resume model + optimizer state from a checkpoint.
    if args.start_from is not None:
        logger.info('Loading pretrained model from {} ...'.format(args.start_from))
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
    if args.gpu_id != -1:
        model.cuda(args.gpu_id)

    batch_acm = 0     # batches consumed (drives gradient accumulation)
    global_step = 0   # optimizer steps taken (one per args.cum_num batches)
    loss_p_xz_acm, loss_q_zx_acm, loss_image_l2_acm, loss_sketch_l2_acm, loss_reg_l2_acm, loss_reg_l1_acm = 0., 0., 0., 0., 0., 0.,
    best_precision = 0.
    best_iter = 0
    patience = args.patience
    logger.info('Hyper-Parameter:')
    logger.info(args)
    logger.info('Model Structure:')
    logger.info(model)
    logger.info('Begin Training !')
    while True:
        if patience <= 0:
            break
        for sketch_batch, image_batch, semantics_batch in dataloader_train:
            # Log running loss averages every print_every optimizer steps.
            # (The original `== 0 % args.print_every` was a no-op: 0 % x == 0.)
            if global_step % args.print_every == 0 and global_step and batch_acm % args.cum_num == 0:
                logger.info('Iter {}, Loss/p_xz {:.3f}, Loss/q_zx {:.3f}, Loss/image_l2 {:.3f}, Loss/sketch_l2 {:.3f}, Loss/reg_l2 {:.3f}, Loss/reg_l1 {:.3f}'.format(global_step, \
                             loss_p_xz_acm/args.print_every/args.cum_num, \
                             loss_q_zx_acm/args.print_every/args.cum_num, \
                             loss_image_l2_acm/args.print_every/args.cum_num, \
                             loss_sketch_l2_acm/args.print_every/args.cum_num, \
                             loss_reg_l2_acm/args.print_every/args.cum_num, \
                             loss_reg_l1_acm/args.print_every/args.cum_num))
                loss_p_xz_acm, loss_q_zx_acm, loss_image_l2_acm, loss_sketch_l2_acm, loss_reg_l2_acm, loss_reg_l1_acm = 0., 0., 0., 0., 0., 0.,

            # Checkpoint and evaluate every save_every optimizer steps.
            # (Same `0 % args.save_every` no-op simplified here.)
            if global_step % args.save_every == 0 and batch_acm % args.cum_num == 0 and global_step:
                if not os.path.exists(args.save_dir):
                    os.mkdir(args.save_dir)
                torch.save({'args':args, 'model':model.state_dict(), \
                        'optimizer':optimizer.state_dict()},
                        '{}/Iter_{}.pkl'.format(args.save_dir,global_step))

                ### Evaluation: hash all test images and sketches, then
                ### measure retrieval precision of sketch->image lookup.
                model.eval()

                image_label = list()
                image_feature = list()
                for image, label in data.load_test_images(batch_size=args.batch_size):
                    image = image.cuda(args.gpu_id)
                    image_label += label
                    tmp_feature = model.hash(image, 1).cpu().detach().numpy()
                    image_feature.append(tmp_feature)
                image_feature = np.vstack(image_feature)

                sketch_label = list()
                sketch_feature = list()
                for sketch, label in data.load_test_sketch(batch_size=args.batch_size):
                    sketch = sketch.cuda(args.gpu_id)
                    sketch_label += label
                    tmp_feature = model.hash(sketch, 0).cpu().detach().numpy()
                    sketch_feature.append(tmp_feature)
                sketch_feature = np.vstack(sketch_feature)

                # NOTE(review): variables/metric names say "cosine" but the
                # distance actually computed is Hamming (appropriate for
                # binary hash codes) — confirm the naming is just legacy.
                dists_cosine = cdist(image_feature, sketch_feature, 'hamming')

                rank_cosine = np.argsort(dists_cosine, 0)

                for n in [5, 100, 200]:
                    ranksn_cosine = rank_cosine[:n, :].T

                    classesn_cosine = np.array([[image_label[i] == sketch_label[r] \
                                                for i in ranksn_cosine[r]] for r in range(len(ranksn_cosine))])

                    precision_cosine = np.mean(classesn_cosine)

                    writer.add_scalar('Precision_{}/cosine'.format(n),
                            precision_cosine, global_step)

                    logger.info('Iter {}, Precision_{}/cosine {}'.format(global_step, n, precision_cosine))

                # After the loop precision_cosine holds Precision@200,
                # which is the early-stopping / model-selection criterion.
                if best_precision < precision_cosine:
                    patience = args.patience
                    best_precision = precision_cosine
                    best_iter = global_step
                    writer.add_scalar('Best/Precision_200', best_precision, best_iter)
                    logger.info('Iter {}, Best Precision_200 {}'.format(global_step, best_precision))
                    torch.save({'args':args, 'model':model.state_dict(), \
                        'optimizer':optimizer.state_dict()}, '{}/Best.pkl'.format(args.save_dir))
                else:
                    patience -= 1
            if patience <= 0:
                break

            model.train()
            batch_acm += 1
            # Linear learning-rate warm-up over the first warmup_steps steps.
            if global_step <= args.warmup_steps:
                update_lr(optimizer, args.lr*global_step/args.warmup_steps)
            """
            #code for testing if the images and the sketches are corresponding to each other correctly

            for i in range(args.batch_size):
                sk = sketch_batch[i].numpy().reshape(224, 224, 3)
                im = image_batch[i].numpy().reshape(224, 224, 3)
                print(label[i])
                ims = np.vstack((np.uint8(sk), np.uint8(im)))
                cv2.imshow('test', ims)
                cv2.waitKey(3000)
            """

            sketch = sketch_batch.cuda(args.gpu_id)
            image = image_batch.cuda(args.gpu_id)
            semantics = semantics_batch.long().cuda(args.gpu_id)

            optimizer.zero_grad()
            loss = model(sketch, image, semantics)
            loss_l1 = l1_regularization()
            loss_l2 = l2_regularization()
            loss_p_xz_acm += loss['p_xz'][0].item()
            loss_q_zx_acm += loss['q_zx'][0].item()
            loss_image_l2_acm += loss['image_l2'][0].item()
            loss_sketch_l2_acm += loss['sketch_l2'][0].item()
            loss_reg_l1_acm += loss_l1.item()
            loss_reg_l2_acm += (loss_l2.item() / 0.005)
            writer.add_scalar('Loss/p_xz', loss['p_xz'][0].item(), global_step)
            writer.add_scalar('Loss/q_zx', loss['q_zx'][0].item(), global_step)
            writer.add_scalar('Loss/image_l2', loss['image_l2'][0].item(), global_step)
            writer.add_scalar('Loss/sketch_l2', loss['sketch_l2'][0].item(), global_step)
            writer.add_scalar('Loss/reg_l2', (loss_l2.item() / 0.005), global_step)
            writer.add_scalar('Loss/reg_l1', loss_l1.item(), global_step)

            # Total loss = L2 penalty + weighted model losses; each entry
            # of `loss` appears to be a (value, weight) pair.
            # NOTE(review): loss_l1 is logged above but never added to the
            # optimized loss — confirm this is intentional.
            loss_ = loss_l2
            for item in loss.values():
                loss_ += item[0]*item[1]
            loss_.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            # Step the optimizer only once every cum_num accumulated batches.
            if batch_acm % args.cum_num == 0:
                optimizer.step()
                global_step += 1
    def run_training(self):
        if self.config["general"]["restart"]:
            mod_ckpt, op_ckpt = self._load_ckpt("reg_ckpt")
            reg_ckpt, op_reg = self._load_ckpt("regressor")

            # remove gamma from optimizer param groups and set self.gamma
            for g in op_ckpt["param_groups"]:
                self.gamma = g["gamma"]
                del g["gamma"]

        else:
            mod_ckpt = op_ckpt = None
            reg_ckpt = op_reg = None

        dataset, transforms = get_dataset(self.config["data"])
        train_dataset = dataset(transforms,
                                data_keys=self.data_keys,
                                mode="train",
                                debug=self.config["general"]["debug"],
                                **self.config["data"],
                                **self.config["training"])
        # update datakeys, if new ones have been added
        self.data_keys = train_dataset.datakeys
        print(f"Length of train dataset is {len(train_dataset)}")

        # compute sampling distribution
        if self.config["data"]["sampling"] == "full":
            area_distribution = parallel_data_prefetch(
                partial(get_area_sampling_dist, kp_subset=None),
                train_dataset.datadict["keypoints"],
                self.config["data"]["n_data_workers"],
            )
            sampling_distribution = area_distribution / np.sum(
                area_distribution)
        elif self.config["data"]["sampling"] == "body":
            area_distribution = parallel_data_prefetch(
                partial(
                    get_area_sampling_dist,
                    kp_subset=train_dataset.joint_model.body,
                ),
                train_dataset.datadict["keypoints"],
                self.config["data"]["n_data_workers"],
            )
            sampling_distribution = area_distribution / np.sum(
                area_distribution)

        elif self.config["data"]["sampling"] == "pid":
            upids, counts = np.unique(train_dataset.datadict["p_ids"],
                                      return_counts=True)
            sampling_distribution = np.zeros_like(
                train_dataset.datadict["p_ids"], dtype=np.float)
            for pid, n in zip(upids, counts):
                sampling_distribution[train_dataset.datadict["p_ids"] ==
                                      pid] = (1.0 / n)

            assert np.all(sampling_distribution > 0.0)
            sampling_distribution = sampling_distribution / np.sum(
                sampling_distribution)

        else:
            sampling_distribution = None
        sampler = PerPersonSampler(train_dataset,
                                   sampling_dist=sampling_distribution)
        train_loader = DataLoader(
            train_dataset,
            self.config["training"]["batch_size"],
            sampler=sampler,
            drop_last=True,
            num_workers=0 if self.config['general']['debug'] else
            self.config["data"]["n_data_workers"],
        )
        test_dataset = dataset(transforms,
                               data_keys=self.data_keys,
                               mode="test",
                               debug=self.config["general"]["debug"],
                               **self.config["data"],
                               **self.config["training"])
        print(f"Length of test dataset is {len(test_dataset)}")
        test_loader = DataLoader(
            test_dataset,
            batch_size=self.config["logging"]["n_test_samples"],
            shuffle=True,
            drop_last=True,
            num_workers=0 if self.config['general']['debug'] else
            self.config["data"]["n_data_workers"])
        self.test_iterator = iter(test_loader)

        met_loader = DataLoader(
            test_dataset,
            batch_size=self.config["metrics"]["test_batch_size"],
            shuffle=True,
            drop_last=False,
            num_workers=0 if self.config['general']['debug'] else
            self.config["data"]["n_data_workers"])

        n_channels_x = (3 * len(train_dataset.joint_model.norm_T)
                        if self.config["data"]["inplane_normalize"] else 3)
        if self.config["architecture"]["conv_layer_type"] == "l2":

            def init_fn():
                return self.global_step <= self.config["training"][
                    "n_init_batches"]

            vunet = VunetAlter(init_fn=init_fn,
                               n_channels_x=n_channels_x,
                               **self.config["architecture"],
                               **self.config["data"])
        else:
            vunet = VunetAlter(n_channels_x=n_channels_x,
                               **self.config["architecture"],
                               **self.config["data"])

        if mod_ckpt is not None and not self.config["general"]["debug"]:
            vunet.load_state_dict(mod_ckpt)

        if self.parallel:
            vunet = nn.DataParallel(vunet, device_ids=self.all_devices)

        vunet.to(self.all_devices[0])
        n_trainable_params = n_parameters(vunet)
        print(f"Number of trainable params is {n_trainable_params}")
        ssize = self.config["data"]["spatial_size"]
        print(f"Spatial size is {ssize}")

        vgg = vgg19(pretrained=True)
        if self.parallel:
            vgg = nn.DataParallel(vgg, device_ids=self.all_devices)

        vgg.to(self.all_devices[0])
        vgg.eval()

        custom_vgg = PerceptualVGG(vgg, self.config["training"]["vgg_weights"])
        if self.parallel:
            custom_vgg = nn.DataParallel(custom_vgg,
                                         device_ids=self.all_devices)

        custom_vgg.to(self.all_devices[0])

        optimizer = Adam(
            [
                {
                    "params": get_member(vunet, "eu").parameters(),
                    "name": "eu"
                },
                {
                    "params": get_member(vunet, "ed").parameters(),
                    "name": "ed"
                },
                {
                    "params": get_member(vunet, "du").parameters(),
                    "name": "du"
                },
                {
                    "params": get_member(vunet, "dd").parameters(),
                    "name": "dd"
                },
            ],
            lr=self.config["training"]["lr"],
            betas=self.config["training"]["adam_betas"],
        )

        if op_ckpt is not None and not self.config["general"]["debug"]:
            optimizer.load_state_dict(op_ckpt)
            # note this may not work for different optimizers
            start_it = list(
                optimizer.state_dict()["state"].values())[-1]["step"]
        else:
            start_it = 0

        latent_widths = [
            self.config["data"]["spatial_size"] // (2**(vunet.n_scales - i))
            for i in range(self.config["architecture"]["n_latent_scales"], 0,
                           -1)
        ]
        regressor = Regressor(len(test_dataset.joint_model.kps_to_use) * 2,
                              latent_widths=latent_widths,
                              **self.config["architecture"])

        if mod_ckpt is not None and not self.config["general"]["debug"]:
            regressor.load_state_dict(reg_ckpt)
        regressor = regressor.to(self.all_devices[0])
        optimizer_regressor = Adam(regressor.parameters(), lr=0.001)
        if op_reg is not None and not self.config["general"]["debug"]:
            optimizer_regressor.load_state_dict(op_reg)

        print(
            "Number of parameters in regressor",
            sum(p.numel() for p in regressor.parameters()),
        )

        self.global_step = start_it

        if self.config["training"]["end_iteration"] <= start_it:
            raise ValueError(
                "The start iteration is higher or equal than the end iteration. If you want to resume training, adapt end iteration"
            )

        if self.config["general"]["debug"]:
            n_epoch = 1
            n_overall_epoch = 1
        else:
            n_epoch = int(
                ceil(
                    float(self.config["training"]["end_iteration"] - start_it)
                    * self.config["training"]["batch_size"] /
                    len(train_dataset)))
            n_overall_epoch = int(
                ceil(
                    float(self.config["training"]["end_iteration"]) *
                    self.config["training"]["batch_size"] /
                    len(train_dataset)))

        print(f"Starting training for {n_epoch} Epochs!")

        total_steps = n_overall_epoch * len(
            train_dataset) // self.config["training"]["batch_size"]

        print(f"Overall {total_steps} train steps to take...")

        adjust_lr = partial(
            linear_var,
            start_it=0,
            end_it=total_steps,
            start_val=self.config["training"]["lr"],
            end_val=0,
            clip_min=0,
            clip_max=self.config["training"]["lr"],
        )
        if self.config["training"]["imax_scaling"] == "ascend":
            # this is as in beta vae
            start_val_imax = 0
            end_val_imax = self.config["training"]["information_max"]
        elif self.config["training"]["imax_scaling"] == "descend":
            start_val_imax = self.config["training"]["information_max"]
            end_val_imax = 0
        else:
            start_val_imax = end_val_imax = self.config["training"][
                "information_max"]

        adjust_imax = partial(
            linear_var,
            start_it=0,
            end_it=total_steps,
            start_val=start_val_imax,
            end_val=end_val_imax,
            clip_min=min(start_val_imax, end_val_imax),
            clip_max=max(start_val_imax, end_val_imax),
        )

        self.lr = adjust_lr(start_it)
        self.imax = adjust_imax(start_it)
        print(
            f"Learning rate after adjusting it for the first time is {self.lr}"
        )
        print(
            f"Informmation max after adjusting it for the first time is {self.imax}"
        )

        ish36m = isinstance(train_dataset, Human36mDataset)

        self.kl_avg = []
        self.kl_mean = 0
        self.kl_it = 0

        for pg in optimizer.param_groups:
            pg["lr"] = self.lr

        def train_fn(engine, batch):
            vunet.train()

            self.global_step = engine.state.iteration

            if self.parallel:
                imgs = {name: batch[name] for name in train_dataset.datakeys}
            else:
                imgs = {
                    name: batch[name].to(self.device)
                    for name in train_dataset.datakeys
                }
                # app_img = imgs["app_img"]
            target_img = imgs["pose_img"]
            shape_img = imgs["stickman"]
            # keypoints correspond to pose_img
            if not ish36m:
                pose_img = imgs["pose_img_inplane"]
            else:
                pose_img = target_img

            # apply vunet
            with torch.enable_grad():
                out_img, means, logstds, activations = vunet(
                    pose_img, shape_img)

            # if self.parallel and weights is not None:
            #     weights = weights.to(self.all_devices[-1])
            likelihood_loss_dict = vgg_loss(
                custom_vgg,
                target_img,
                out_img,
            )
            likelihoods = torch.stack(
                [likelihood_loss_dict[key] for key in likelihood_loss_dict],
                dim=0,
            )
            likelihood_loss = self.config["training"]["ll_weight"] * torch.sum(
                likelihoods)
            kl_loss = compute_kl_with_prior(means, logstds)

            loss = likelihood_loss
            tuning = (torch.tensor(1.0, device=self.device)
                      if self.config["architecture"]["cvae"] else torch.tensor(
                          self.gamma, device=self.device))
            if engine.state.iteration > self.config["training"][
                    "n_init_batches"]:
                loss = (loss + tuning * kl_loss).to(self.device)

            if self.config["training"]["train_regressor"]:
                reg_imgs = imgs["reg_imgs"]
                reg_targets = imgs["reg_targets"]

                for i in range(reg_imgs.shape[1]):
                    with torch.no_grad():
                        _, means, _, _ = vunet.ed(vunet.eu(reg_imgs[:, i]))

                    preds = regressor(means)
                    tgts = reg_targets[:,
                                       i].reshape(reg_targets[:, i].shape[0],
                                                  -1)
                    loss_regressor = torch.norm((preds - tgts), dim=1).mean()
                    optimizer_regressor.zero_grad()
                    loss_regressor.backward(retain_graph=True)
                    optimizer_regressor.step()
                # loss_regressor = torch.exp(-4*torch.clamp(loss_regressor, max=0.08))
                loss_reg_model = torch.clamp(loss_regressor, max=1.2)
                loss -= loss_reg_model * self.config["training"][
                    "weight_regressor"]

            log_lr = torch.tensor(self.lr)

            optimizer.zero_grad()
            loss.backward()
            # optimize
            optimizer.step()

            # fixme this won't work with loaded checkpoint
            # keep moving average
            if (engine.state.iteration - start_it) < 100:
                self.kl_avg.append(kl_loss.detach().cpu().numpy())
            else:
                self.kl_avg.pop(0)
                self.kl_avg.append(kl_loss.detach().cpu().numpy())

            self.gamma = self.__update_gamma(kl_loss)

            output_dict = {
                "loss": loss.item(),
                "likelihood_loss": likelihood_loss.item(),
                "kl_loss": kl_loss.item(),
                "learning_rate": log_lr,
                "gamma": self.gamma,
                "imax": self.imax,
                "kl_mean": self.kl_mean,
            }
            if self.config["training"]["train_regressor"]:
                output_dict.update({
                    "loss_reg":
                    loss_regressor.item(),
                    "weight_regressor":
                    self.config["training"]["weight_regressor"],
                })

            likelihood_loss_dict = {
                key: likelihood_loss_dict[key].item()
                for key in likelihood_loss_dict
            }
            output_dict.update(likelihood_loss_dict)
            return output_dict

        trainer = Engine(train_fn)

        # checkpointing
        ckpt_handler = ModelCheckpoint(self.dirs["ckpt"],
                                       "reg_ckpt",
                                       n_saved=10,
                                       require_empty=False)
        save_dict = {
            "model": vunet.module if self.parallel else vunet,
            "optimizer": optimizer,
        }
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(
                every=self.config["logging"]["ckpt_steps"]),
            ckpt_handler,
            save_dict,
        )
        ckpt_reg = ModelCheckpoint(self.dirs["ckpt"],
                                   "regressor",
                                   n_saved=10,
                                   require_empty=False)
        reg_sdict = save_dict = {
            "model": regressor,
            "optimizer": optimizer_regressor,
        }
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(
                every=self.config["logging"]["ckpt_steps"]),
            ckpt_reg,
            reg_sdict,
        )

        # tensorboard logs
        tb_dir = self.dirs["log"]
        writer = add_summary_writer(vunet, tb_dir)

        @trainer.on(Events.ITERATION_COMPLETED)
        def adjust_params(engine):
            # update learning rate and kl_loss weight
            it = engine.state.iteration

            self.lr = adjust_lr(it)
            self.imax = adjust_imax(it)
            for pg in optimizer.param_groups:
                pg["lr"] = self.lr
                if "gamma" in pg:
                    pg["gamma"] = self.gamma
                else:
                    pg.update({"gamma": self.gamma})

        pbar = ProgressBar(ascii=True)
        pbar.attach(trainer, output_transform=lambda x: x)

        @trainer.on(
            Events.ITERATION_COMPLETED(
                every=self.config["logging"]["log_steps"]))
        def log(engine):

            wandb.log({"iteration": engine.state.iteration})

            it = engine.state.iteration
            data = engine.state.batch
            if self.parallel:
                imgs = {
                    name: data[name]
                    for name in data if name != "sample_ids"
                }
            else:
                imgs = {
                    name: data[name].to(self.device)
                    for name in data if name not in ["sample_ids"]
                }
            app_img = imgs["app_img"]
            shape_img = imgs["stickman"]
            target_img = imgs["pose_img"]
            if not isinstance(train_dataset, Human36mDataset):
                pose_img = imgs["pose_img_inplane"]

            else:
                pose_img = target_img

            vunet.eval()
            # visualize current performance on train set
            with torch.no_grad():
                # here's the difference to vunet: use NO app img but target image --> no person id necessary
                out_img, _, _, _ = vunet(pose_img, shape_img)

                out_img = scale_img(out_img)
                shape_img = scale_img(shape_img)
                target_img = scale_img(target_img)
                pose_img = scale_img(pose_img)

                writer.add_images(
                    "appearance_images",
                    target_img if self.config["data"]["inplane_normalize"] else
                    pose_img[:self.config["logging"]["n_logged_img"]],
                    it,
                )
                writer.add_images(
                    "shape_images",
                    shape_img[:self.config["logging"]["n_logged_img"]], it)
                writer.add_images(
                    "target_images",
                    target_img[:self.config["logging"]["n_logged_img"]], it)
                writer.add_images(
                    "transferred_images",
                    out_img[:self.config["logging"]["n_logged_img"]], it)

                [
                    writer.add_scalar(key, val, it)
                    for key, val in engine.state.output.items()
                ]

            # test
            try:
                batch = next(self.test_iterator)
            except StopIteration:
                self.test_iterator = iter(test_loader)
                batch = next(self.test_iterator)

            if self.parallel:
                imgs = {name: batch[name] for name in self.data_keys}
            else:
                imgs = {
                    name: batch[name].to(self.device)
                    for name in self.data_keys
                }

            app_img = imgs["app_img"]
            timg = imgs["pose_img"]
            shape_img = imgs["stickman"]
            if self.config["data"]["inplane_normalize"]:
                pose_img = imgs["pose_img_inplane"]
                pose_ids = np.squeeze(
                    imgs["sample_ids"].cpu().numpy()).tolist()
                app_img_disp = test_dataset._get_app_img(
                    ids=pose_ids, inplane_norm=False).squeeze()
            else:
                pose_img = timg
                app_img_disp = app_img

            with torch.no_grad():
                # test reconstruction
                rec_img, _, _, _ = vunet(pose_img, shape_img)
                rec_img = scale_img(rec_img)

                # test appearance transfer
                if self.parallel:
                    tr_img = vunet.module.transfer(app_img.to(self.device),
                                                   shape_img.to(self.device))
                else:
                    tr_img = vunet.transfer(app_img, shape_img)
                tr_img = scale_img(tr_img)

                # test sampling mode
                if self.parallel:
                    sampled = vunet.module.test_forward(
                        shape_img.to(self.device))
                else:
                    sampled = vunet.test_forward(shape_img)
                sampled = scale_img(sampled)

            # scale also imputs
            app_img = scale_img(app_img_disp)
            timg = scale_img(timg)
            shape_img = scale_img(shape_img)

            writer.add_images(
                "test-reconstruct",
                make_img_grid([
                    timg.to(self.device),
                    shape_img.to(self.device),
                    rec_img.to(self.device),
                ]),
                it,
            )
            writer.add_images(
                "test-transfer",
                make_img_grid([
                    app_img.to(self.device),
                    shape_img.to(self.device),
                    tr_img.to(self.device),
                ]),
                it,
            )
            writer.add_images(
                "test-sample",
                make_img_grid(
                    [shape_img.to(self.device),
                     sampled.to(self.device)]),
            )

        # Directory for dumping test-time inference results.
        # NOTE(review): infer_dir appears unused within this view — confirm a
        # later handler writes into it.
        infer_dir = path.join(self.dirs["generated"], "test_inference")
        if not path.isdir(infer_dir):
            os.makedirs(infer_dir)

        @trainer.on(Events.ITERATION_COMPLETED)
        def compute_eval_metrics(engine):
            """Compute evaluation metrics and save metric-tagged checkpoints.

            Every ``n_it_metrics`` iterations this synthesizes reconstruction
            and appearance-transfer images over ``met_loader``, scores them
            with the inception score and SSIM, logs the scores to TensorBoard,
            and saves model/optimizer checkpoints whose filenames embed the
            epoch and metric values.
            """
            if (engine.state.iteration +
                    1) % self.config["metrics"]["n_it_metrics"] == 0:
                # compute metrics
                # NOTE(review): eval mode is not switched back to train() in
                # this handler — presumably the training step re-enables it;
                # confirm.
                vunet.eval()
                tr_imgs = []
                rec_imgs = []
                # Cap the number of synthesized samples (tiny cap in debug mode).
                n_samples = 40 if self.config['general'][
                    'debug'] else self.config["metrics"]["max_n_samples"]
                for i, batch in enumerate(
                        tqdm(
                            met_loader,
                            total=n_samples // met_loader.batch_size,
                            desc=
                            f"Synthesizing {n_samples} images for IS computation."
                        )):
                    if i * met_loader.batch_size >= n_samples:
                        break
                    # Move the batch to the device unless running DataParallel
                    # (which scatters inputs itself).
                    if self.parallel:
                        imgs = {name: batch[name] for name in self.data_keys}
                    else:
                        imgs = {
                            name: batch[name].to(self.device)
                            for name in self.data_keys
                        }

                    app_img = imgs["app_img"]
                    # Use the inplane-normalized pose image when configured.
                    timg = (imgs["pose_img_inplane"]
                            if self.config["data"]["inplane_normalize"] else
                            imgs["pose_img"])
                    shape_img = imgs["stickman"]

                    with torch.no_grad():
                        rec_img, _, _, _ = vunet(timg, shape_img)
                        # NOTE(review): unlike the logging handler above, this
                        # calls vunet.transfer directly without the .module
                        # indirection — may fail when self.parallel; confirm.
                        tr_img = vunet.transfer(app_img, shape_img)

                    tr_img_cp = deepcopy(tr_img)
                    rec_img_cp = deepcopy(rec_img)
                    rec_imgs.append(rec_img_cp.detach().cpu())
                    tr_imgs.append(tr_img_cp.detach().cpu())

                    # Explicitly drop GPU references to keep memory bounded
                    # while accumulating CPU copies.
                    del rec_img
                    del timg
                    del shape_img
                    del app_img
                    del tr_img

                tr_imgs = torch.cat(tr_imgs, dim=0)
                rec_imgs = torch.cat(rec_imgs, dim=0)

                tr_dataset = torch.utils.data.TensorDataset(tr_imgs)
                rec_dataset = torch.utils.data.TensorDataset(rec_imgs)

                # Inception scores for reconstructions and transfers.
                is_rec, std_rec = inception_score(
                    rec_dataset,
                    self.device,
                    resize=True,
                    batch_size=self.config["metrics"]["test_batch_size"],
                    debug=self.config['general']['debug'])
                is_tr, std_tr = inception_score(
                    tr_dataset,
                    self.device,
                    batch_size=self.config["metrics"]["test_batch_size"],
                    resize=True,
                    debug=self.config['general']['debug'])

                ssim = compute_ssim(vunet,
                                    self.all_devices,
                                    data_keys=self.data_keys,
                                    debug=self.config["general"]["debug"],
                                    **self.config["data"],
                                    **self.config["training"],
                                    **self.config["metrics"])

                # add to tensorboard
                it = engine.state.iteration
                # writer.add_scalar("fid", fid, it)
                writer.add_scalar("ssim", ssim, it)
                writer.add_scalar("is_recon", is_rec, it)
                writer.add_scalar("is_transfer", is_tr, it)
                writer.add_scalar("std_is_recon", std_rec, it)
                writer.add_scalar("std_is_transfer", std_tr, it)

                # save checkpoint to separate dir which contains the checkpoints based on metrics
                save_dir = path.join(self.dirs["ckpt"], "epoch_ckpts")
                os.makedirs(save_dir, exist_ok=True)

                torch.save(
                    vunet.state_dict(),
                    path.join(
                        save_dir,
                        f"model@e{engine.state.epoch}@ssim={ssim}-is={is_rec}.pth",
                    ),
                )
                torch.save(
                    optimizer.state_dict(),
                    path.join(
                        save_dir,
                        f"opt@e{engine.state.epoch}@ssim={ssim}-is={is_rec}.pth",
                    ),
                )

        @trainer.on(Events.STARTED)
        def set_start_it(engine):
            """Resume the engine's iteration counter from the checkpoint."""
            # Continue one iteration past the checkpointed `start_it`.
            resume_at = start_it + 1
            engine.state.iteration = resume_at
            print(f"Engine starting from iteration #{resume_at}.")

        @trainer.on(Events.ITERATION_STARTED)
        def stop(engine):
            """Terminate training once the configured end iteration is reached."""
            current = engine.state.iteration
            # Guard clause: nothing to do while below the configured cutoff.
            if current < self.config["training"]["end_iteration"]:
                return
            print(
                f"Current iteration is {current}: Training terminating after this iteration."
            )
            engine.terminate()

        # Run the training loop; the ITERATION_STARTED handler above enforces
        # the iteration-based cutoff before `n_epoch` epochs complete.
        trainer.run(train_loader, max_epochs=n_epoch)
def train(model, state, path, annotations, val_path, val_annotations, resize, max_size, jitter, batch_size, iterations,
          val_iterations, mixed_precision, lr, warmup, milestones, gamma, is_master=True, world=1, use_dali=True,
          verbose=True, metrics_url=None, logdir=None):
    """Train the model on the given dataset.

    Runs an iteration-based training loop with apex-style AMP mixed
    precision, optional distributed data parallelism, periodic console /
    TensorBoard logging, checkpointing, and validation-driven LR scheduling.

    Args:
        model: detection model exposing ``.stride`` and ``.save(state)``.
        state: resume dict; 'iteration' and 'optimizer' are restored if present
            and updated in place on every checkpoint.
        path, annotations: training images and annotation file.
        val_path, val_annotations: validation data; enables periodic validation.
        resize, max_size, jitter: image sizing / augmentation parameters.
        batch_size, iterations: per-step batch size and total iterations.
        val_iterations: validate every this many iterations.
        mixed_precision: use amp opt_level 'O2' when True, 'O0' otherwise.
        lr, warmup, milestones, gamma: optimizer/schedule hyperparameters
            (warmup/milestones/gamma only feed the disabled LambdaLR schedule).
        is_master: this process prints, logs, and saves checkpoints.
        world: number of distributed workers (losses are all-reduced if > 1).
        use_dali: use the DALI data iterator instead of the plain one.
        verbose, metrics_url, logdir: console output, metrics endpoint,
            TensorBoard log directory.
    """
    # NOTE(review): the `path` parameter shadows any module-level `path`
    # (e.g. `from os import path`) used elsewhere in this file.

    # Prepare model
    # Keep a handle to the unwrapped model: checkpoints are saved through it
    # after `model` is converted/wrapped (cuda / amp / DDP) below.
    nn_model = model
    stride = model.stride

    model = convert_fixedbn_model(model)
    if torch.cuda.is_available():
        model = model.cuda()

    # Setup optimizer and schedule
    # optimizer = SGD(model.parameters(), lr=lr, weight_decay=0.0001, momentum=0.9)
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=0.0000001)

    # apex AMP: 'O2' = mixed precision, 'O0' = full-precision passthrough.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level='O2' if mixed_precision else 'O0',
                                      keep_batchnorm_fp32=True,
                                      loss_scale=128.0,
                                      verbosity=is_master)

    if world > 1:
        model = DistributedDataParallel(model)
    model.train()

    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])

    # NOTE: previous warmup/milestone LambdaLR schedule, kept for reference
    # as an inert string literal.
    '''
    def schedule(train_iter):
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return gamma ** len([m for m in milestones if m <= train_iter])
    scheduler = LambdaLR(optimizer, schedule)
    '''

    # NOTE(review): mode='min' but the scheduler is stepped with a validation
    # f1 score (higher is better) below — likely should be mode='max'; confirm.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)

    # Prepare dataset
    if verbose: print('Preparing dataset...')
    data_iterator = (DaliDataIterator if use_dali else DataIterator)(
        path, jitter, max_size, batch_size, stride,
        world, annotations, training=True)
    if verbose: print(data_iterator)

    if verbose:
        print('    device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available() else 'gpu' if world == 1 else 'gpus'))
        print('    batch: {}, precision: {}'.format(batch_size, 'mixed' if mixed_precision else 'full'))
        print('Training model for {} iterations...'.format(iterations))

    # Create TensorBoard writer
    if logdir is not None:
        from tensorboardX import SummaryWriter
        if is_master and verbose:
            print('Writing TensorBoard logs to: {}'.format(logdir))
        writer = SummaryWriter(logdir=logdir)

    profiler = Profiler(['train', 'fw', 'bw'])
    # Resume from the checkpointed iteration when present.
    iteration = state.get('iteration', 0)
    while iteration < iterations:
        cls_losses, box_losses = [], []
        for i, (data, target) in enumerate(data_iterator):
            # scheduler.step(iteration)

            # Forward pass
            profiler.start('fw')

            optimizer.zero_grad()
            cls_loss, box_loss = model([data, target])
            # Drop the input batch before backward to reduce peak memory.
            del data
            profiler.stop('fw')

            # Backward pass
            profiler.start('bw')
            with amp.scale_loss(cls_loss + box_loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()

            # Reduce all losses
            cls_loss, box_loss = cls_loss.mean().clone(), box_loss.mean().clone()
            if world > 1:
                torch.distributed.all_reduce(cls_loss)
                torch.distributed.all_reduce(box_loss)
                cls_loss /= world
                box_loss /= world
            if is_master:
                cls_losses.append(cls_loss)
                box_losses.append(box_loss)

            # Abort early on NaN/inf loss rather than training on garbage.
            if is_master and not isfinite(cls_loss + box_loss):
                raise RuntimeError('Loss is diverging!\n{}'.format(
                    'Try lowering the learning rate.'))

            del cls_loss, box_loss
            profiler.stop('bw')

            iteration += 1
            profiler.bump('train')
            # Report and checkpoint roughly once per minute of training time,
            # and always on the final iteration.
            if is_master and (profiler.totals['train'] > 60 or iteration == iterations):
                focal_loss = torch.stack(list(cls_losses)).mean().item()
                box_loss = torch.stack(list(box_losses)).mean().item()
                learning_rate = optimizer.param_groups[0]['lr']
                if verbose:
                    msg = '[{:{len}}/{}]'.format(iteration, iterations, len=len(str(iterations)))
                    msg += ' focal loss: {:.3f}'.format(focal_loss)
                    msg += ', box loss: {:.3f}'.format(box_loss)
                    msg += ', {:.3f}s/{}-batch'.format(profiler.means['train'], batch_size)
                    msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(profiler.means['fw'], profiler.means['bw'])
                    msg += ', {:.1f} im/s'.format(batch_size / profiler.means['train'])
                    msg += ', lr: {:.2g}'.format(learning_rate)
                    print(msg, flush=True)

                if logdir is not None:
                    writer.add_scalar('focal_loss', focal_loss, iteration)
                    writer.add_scalar('box_loss', box_loss, iteration)
                    writer.add_scalar('learning_rate', learning_rate, iteration)
                    # NOTE(review): this del only runs when logdir is set —
                    # presumably harmless since neither name is read below;
                    # confirm.
                    del box_loss, focal_loss

                if metrics_url:
                    # NOTE(review): assumes `mean` accepts a list of 0-dim
                    # tensors — confirm against its import.
                    post_metrics(metrics_url, {
                        'focal loss': mean(cls_losses),
                        'box loss': mean(box_losses),
                        'im_s': batch_size / profiler.means['train'],
                        'lr': learning_rate
                    })

                # Save model weights
                state.update({
                    'iteration': iteration,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                })
                # Block SIGINT so Ctrl-C cannot corrupt the checkpoint write.
                with ignore_sigint():
                    nn_model.save(state)

                profiler.reset()
                del cls_losses[:], box_losses[:]

            # Periodic validation; steps the plateau scheduler on the f1 score.
            if val_annotations and (iteration == iterations or iteration % val_iterations == 0):
                f1_m = infer(model, val_path, None, resize, max_size, batch_size, annotations=val_annotations,
                             mixed_precision=mixed_precision, is_master=is_master, world=world, use_dali=use_dali,
                             is_validation=True, verbose=False)
                if not isinstance(f1_m, str):
                    print('f1_m:' + str(f1_m))
                    scheduler.step(f1_m)
                # Restore train mode after validation (infer() presumably
                # switches the model to eval) — confirm.
                model.train()

            if iteration == iterations:
                break

    if logdir is not None:
        writer.close()
Example #22
0
def train(
        hyp,  # path/to/hyp.yaml or hyp dictionary
        opt,
        device,
        callbacks):
    save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, delta = \
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
        opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze, opt.delta

    # Directories
    w = save_dir / 'weights'  # weights dir
    (w.parent if evolve else w).mkdir(parents=True, exist_ok=True)  # make dir
    last, best = w / 'last.pt', w / 'best.pt'

    # Hyperparameters
    if isinstance(hyp, str):
        with open(hyp, errors='ignore') as f:
            hyp = yaml.safe_load(f)  # load hyps dict
    LOGGER.info(
        colorstr('hyperparameters: ') + ', '.join(f'{k}={v}'
                                                  for k, v in hyp.items()))

    # Save run settings
    if not evolve:
        with open(save_dir / 'hyp.yaml', 'w') as f:
            yaml.safe_dump(hyp, f, sort_keys=False)
        with open(save_dir / 'opt.yaml', 'w') as f:
            yaml.safe_dump(vars(opt), f, sort_keys=False)

    # Loggers
    data_dict = None
    if RANK in [-1, 0]:
        loggers = Loggers(save_dir, weights, opt, hyp,
                          LOGGER)  # loggers instance
        if loggers.wandb:
            data_dict = loggers.wandb.data_dict
            if resume:
                weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp

        # Register actions
        for k in methods(loggers):
            callbacks.register_action(k, callback=getattr(loggers, k))

    # Config
    plots = not evolve  # create plots
    cuda = device.type != 'cpu'
    init_seeds(1 + RANK)
    with torch_distributed_zero_first(LOCAL_RANK):
        data_dict = data_dict or check_dataset(data)  # check if None
    train_path, val_path, adv_path = data_dict['train'], data_dict[
        'val'], data_dict["adv"]
    nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
    names = ['item'] if single_cls and len(
        data_dict['names']) != 1 else data_dict['names']  # class names
    assert len(
        names
    ) == nc, f'{len(names)} names found for nc={nc} dataset in {data}'  # check
    is_coco = isinstance(val_path, str) and val_path.endswith(
        'coco/val2017.txt')  # COCO dataset

    # Model
    check_suffix(weights, '.pt')  # check weights
    pretrained = weights.endswith('.pt')
    if pretrained:
        with torch_distributed_zero_first(LOCAL_RANK):
            weights = attempt_download(
                weights)  # download if not found locally
        ckpt = torch.load(weights, map_location=device)  # load checkpoint
        model = Model(cfg or ckpt['model'].yaml,
                      ch=3,
                      nc=nc,
                      anchors=hyp.get('anchors')).to(device)  # create
        exclude = [
            'anchor'
        ] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
        csd = ckpt['model'].float().state_dict(
        )  # checkpoint state_dict as FP32
        csd = intersect_dicts(csd, model.state_dict(),
                              exclude=exclude)  # intersect
        model.load_state_dict(csd, strict=False)  # load
        LOGGER.info(
            f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}'
        )  # report
    else:
        model = Model(cfg, ch=3, nc=nc,
                      anchors=hyp.get('anchors')).to(device)  # create

    # Freeze
    freeze = [f'model.{x}.' for x in range(freeze)]  # layers to freeze
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        if any(x in k for x in freeze):
            LOGGER.info(f'freezing {k}')
            v.requires_grad = False

    # Image size
    gs = max(int(model.stride.max()), 32)  # grid size (max stride)
    imgsz = check_img_size(opt.imgsz, gs,
                           floor=gs * 2)  # verify imgsz is gs-multiple

    # Batch size
    if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
        batch_size = check_train_batch_size(model, imgsz)

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / batch_size),
                     1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
    LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")

    g0, g1, g2 = [], [], []  # optimizer parameter groups
    for v in model.modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias
            g2.append(v.bias)
        if isinstance(v, nn.BatchNorm2d):  # weight (no decay)
            g0.append(v.weight)
        elif hasattr(v, 'weight') and isinstance(
                v.weight, nn.Parameter):  # weight (with decay)
            g1.append(v.weight)

    if opt.adam:
        optimizer = Adam(g0, lr=hyp['lr0'],
                         betas=(hyp['momentum'],
                                0.999))  # adjust beta1 to momentum
    else:
        optimizer = SGD(g0,
                        lr=hyp['lr0'],
                        momentum=hyp['momentum'],
                        nesterov=True)

    optimizer.add_param_group({
        'params': g1,
        'weight_decay': hyp['weight_decay']
    })  # add g1 with weight_decay
    optimizer.add_param_group({'params': g2})  # add g2 (biases)
    LOGGER.info(
        f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
        f"{len(g0)} weight, {len(g1)} weight (no decay), {len(g2)} bias")
    del g0, g1, g2

    # Scheduler
    if opt.linear_lr:
        lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp[
            'lrf']  # linear
    else:
        lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
    scheduler = lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lf)  # plot_lr_scheduler(optimizer, scheduler, epochs)

    # EMA
    ema = ModelEMA(model) if RANK in [-1, 0] else None

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # EMA
        if ema and ckpt.get('ema'):
            ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
            ema.updates = ckpt['updates']

        # Epochs
        start_epoch = ckpt['epoch'] + 1
        if resume:
            assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
        if epochs < start_epoch:
            LOGGER.info(
                f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs."
            )
            epochs += ckpt['epoch']  # finetune additional epochs

        del ckpt, csd

    # DP mode
    if cuda and RANK == -1 and torch.cuda.device_count() > 1:
        LOGGER.warning(
            'WARNING: DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.\n'
            'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.'
        )
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and RANK != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        LOGGER.info('Using SyncBatchNorm()')

    # Trainloader
    print("Domain-Adversarial training")
    (train_loader_s, dataset_s, train_loader_t,
     dataset_t) = create_adv_dataloaders(
         train_path,
         adv_path,
         imgsz,
         batch_size // WORLD_SIZE,
         gs,
         single_cls,
         hyp=hyp,
         augment=True,
         cache=opt.cache,
         rect=opt.rect,
         rank=RANK,
         workers=workers,
         image_weights=opt.image_weights,
         quad=opt.quad,
         prefix=colorstr("train: "),
     )
    mlc = int(np.concatenate(dataset_s.labels, 0)[:,
                                                  0].max())  # max label class
    nb = min(len(train_loader_s), len(train_loader_t))  # number of batches
    assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'

    # Process 0
    if RANK in [-1, 0]:
        val_loader = create_dataloader(val_path,
                                       imgsz,
                                       batch_size // WORLD_SIZE * 2,
                                       gs,
                                       single_cls,
                                       hyp=hyp,
                                       cache=None if noval else opt.cache,
                                       rect=True,
                                       rank=-1,
                                       workers=workers,
                                       pad=0.5,
                                       prefix=colorstr('val: '))[0]

        if not resume:
            labels = np.concatenate(dataset_s.labels, 0)
            # c = torch.tensor(labels[:, 0])  # classes
            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
            # model._initialize_biases(cf.to(device))
            if plots:
                plot_labels(labels, names, save_dir)

            # Anchors
            if not opt.noautoanchor:
                check_anchors(dataset_s,
                              model=model,
                              thr=hyp['anchor_t'],
                              imgsz=imgsz)
            model.half().float()  # pre-reduce anchor precision

        callbacks.run('on_pretrain_routine_end')

    # DDP mode
    if cuda and RANK != -1:
        model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

    # Model attributes
    nl = de_parallel(
        model).model[-1].nl  # number of detection layers (to scale hyps)
    hyp['box'] *= 3 / nl  # scale to layers
    hyp['cls'] *= nc / 80 * 3 / nl  # scale to classes and layers
    hyp['obj'] *= (imgsz / 640)**2 * 3 / nl  # scale to image size and layers
    hyp['label_smoothing'] = opt.label_smoothing
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.class_weights = labels_to_class_weights(
        dataset_s.labels, nc).to(device) * nc  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nw = max(round(hyp['warmup_epochs'] * nb),
             1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    last_opt_step = -1
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0
               )  # P, R, [email protected], [email protected], val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = amp.GradScaler(enabled=cuda)
    stopper = EarlyStopping(patience=opt.patience)
    compute_loss = ComputeLoss(model)  # init loss class
    compute_domain_loss = ComputeDomainLoss(model)  # init domain loss class
    LOGGER.info(
        f'Image sizes {imgsz} train, {imgsz} val\n'
        f'Using {train_loader_s.num_workers + train_loader_t.num_workers * WORLD_SIZE} dataloader workers\n'
        f"Logging results to {colorstr('bold', save_dir)}\n"
        f'Starting training for {epochs} epochs...')
    max_iterations = nb * (epochs - start_epoch)
    for epoch in range(
            start_epoch, epochs
    ):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional, single-GPU only)
        if opt.image_weights:
            cw = model.class_weights.cpu().numpy() * (
                1 - maps)**2 / nc  # class weights
            iw = labels_to_image_weights(dataset_s.labels,
                                         nc=nc,
                                         class_weights=cw)  # image weights
            dataset_s.indices = random.choices(
                range(dataset_s.n), weights=iw,
                k=dataset_s.n)  # rand weighted idx

        # Update mosaic border (optional)
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(3, device=device)  # mean losses
        madvloss = torch.zeros(3, device=device)  # mean adversarial losses
        madvaccuracy = torch.zeros(
            3, device=device)  # mean adversarial accuracies
        if RANK != -1:
            train_loader_s.sampler.set_epoch(epoch)
            train_loader_t.sampler.set_epoch(epoch)
        # pbar = enumerate(train_loader)
        pbar = enumerate(zip(train_loader_s, train_loader_t))
        LOGGER.info(("\n" + "%10s" * 13) %
                    ("Epoch", "gpu_mem", "box", "obj", "cls", "l_small",
                     "l_medium", "l_large", "a_small", "a_medium", "a_large",
                     "labels", "img_size"))
        if RANK in [-1, 0]:
            pbar = tqdm(
                pbar, total=nb,
                bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')  # progress bar
        optimizer.zero_grad()
        for i, (
            (imgs_s, targets_s, paths_s, _), (imgs_t, paths_t, _)
        ) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = torch.cat([imgs_s, imgs_t])
            imgs = imgs.to(device, non_blocking=True).float(
            ) / 255  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(
                    1,
                    np.interp(ni, xi, [1, nbs / batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [
                        hyp['warmup_bias_lr'] if j == 2 else 0.0,
                        x['initial_lr'] * lf(epoch)
                    ])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(
                            ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(imgsz * 0.5,
                                      imgsz * 1.5 + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]
                          ]  # new shape (stretched to gs-multiple)
                    imgs = nn.functional.interpolate(imgs,
                                                     size=ns,
                                                     mode='bilinear',
                                                     align_corners=False)

            # Forward
            with amp.autocast(enabled=cuda):
                r = ni / max_iterations
                gamma = 2 / (1 + math.exp(-delta * r)) - 1

                pred_s, domain_pred_s = model(imgs[:batch_size // 2 //
                                                   WORLD_SIZE],
                                              gamma=gamma,
                                              domain=0,
                                              epoch=epoch)  # forward
                pred_t, domain_pred_t = model(imgs[batch_size // 2 //
                                                   WORLD_SIZE:],
                                              gamma=gamma,
                                              domain=1,
                                              epoch=epoch)  # forward

                loss, loss_items = compute_loss(
                    pred_s, targets_s.to(device))  # loss scaled by batch_size

                domain_loss, domain_loss_items, domain_accuracy_items = compute_domain_loss(
                    domain_pred_s, domain_pred_t)  # loss scaled by batch_size

                if RANK != -1:
                    loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
                if opt.quad:
                    loss *= 4.
            total_loss = loss + domain_loss

            # Backward
            scaler.scale(total_loss).backward()

            # Optimize
            if ni - last_opt_step >= accumulate:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)
                last_opt_step = ni

            # Log
            if RANK in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1
                                                    )  # update mean losses
                madvloss = (madvloss * i + domain_loss_items) / (
                    i + 1)  # update mean losses
                madvaccuracy = (madvaccuracy * i + domain_accuracy_items) / (
                    i + 1)  # update mean accuracies
                mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                pbar.set_description(
                    ('%10s' * 2 + '%10.4g' * 11) %
                    (f'{epoch}/{epochs - 1}', mem, *mloss, *madvloss,
                     *madvaccuracy, targets_s.shape[0], imgs.shape[-1]))
                callbacks.run('on_train_batch_end', ni, model,
                              imgs_s.float().to(device), targets_s, paths_s,
                              plots, opt.sync_bn)
            # end batch ------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
        scheduler.step()

        if RANK in [-1, 0]:
            # mAP
            callbacks.run('on_train_epoch_end', epoch=epoch)
            ema.update_attr(model,
                            include=[
                                'yaml', 'nc', 'hyp', 'names', 'stride',
                                'class_weights'
                            ])
            final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
            if not noval or final_epoch:  # Calculate mAP
                results, maps, _ = adv_val.run(data_dict,
                                               batch_size=batch_size //
                                               WORLD_SIZE * 2,
                                               imgsz=imgsz,
                                               model=ema.ema,
                                               single_cls=single_cls,
                                               dataloader=val_loader,
                                               save_dir=save_dir,
                                               plots=False,
                                               callbacks=callbacks,
                                               compute_loss=compute_loss)

            # Update best mAP
            fi = fitness(np.array(results).reshape(
                1, -1))  # weighted combination of [P, R, [email protected], [email protected]]
            if fi > best_fitness:
                best_fitness = fi
            log_vals = list(mloss) + list(results) + lr + list(
                madvloss) + list(madvaccuracy)
            callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness,
                          fi)

            # Save model
            if (not nosave) or (final_epoch and not evolve):  # if save
                ckpt = {
                    'epoch': epoch,
                    'best_fitness': best_fitness,
                    'model': deepcopy(de_parallel(model)).half(),
                    'ema': deepcopy(ema.ema).half(),
                    'updates': ema.updates,
                    'optimizer': optimizer.state_dict(),
                    'wandb_id':
                    loggers.wandb.wandb_run.id if loggers.wandb else None,
                    'date': datetime.now().isoformat()
                }

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                if (epoch > 0) and (opt.save_period >
                                    0) and (epoch % opt.save_period == 0):
                    torch.save(ckpt, w / f'epoch{epoch}.pt')
                del ckpt
                callbacks.run('on_model_save', last, epoch, final_epoch,
                              best_fitness, fi)

            # Stop Single-GPU
            if RANK == -1 and stopper(epoch=epoch, fitness=fi):
                break

            # Stop DDP TODO: known issues shttps://github.com/ultralytics/yolov5/pull/4576
            # stop = stopper(epoch=epoch, fitness=fi)
            # if RANK == 0:
            #    dist.broadcast_object_list([stop], 0)  # broadcast 'stop' to all ranks

        # Stop DPP
        # with torch_distributed_zero_first(RANK):
        # if stop:
        #    break  # must break all DDP ranks

        # end epoch ----------------------------------------------------------------------------------------------------
    # end training -----------------------------------------------------------------------------------------------------
    if RANK in [-1, 0]:
        LOGGER.info(
            f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.'
        )
        for f in last, best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is best:
                    LOGGER.info(f'\nValidating {f}...')
                    results, _, _ = adv_val.run(
                        data_dict,
                        batch_size=batch_size // WORLD_SIZE * 2,
                        imgsz=imgsz,
                        model=attempt_load(f, device, fuse=False).half(
                        ),  # set fuse to False since we have multiple BatchNorm
                        iou_thres=0.65 if is_coco else
                        0.60,  # best pycocotools results at 0.65
                        single_cls=single_cls,
                        dataloader=val_loader,
                        save_dir=save_dir,
                        save_json=is_coco,
                        verbose=True,
                        plots=True,
                        callbacks=callbacks,
                        compute_loss=compute_loss)  # val best model with plots
                    if is_coco:
                        callbacks.run('on_fit_epoch_end',
                                      list(mloss) + list(results) + lr, epoch,
                                      best_fitness, fi)

        callbacks.run('on_train_end', last, best, plots, epoch, results)
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")

    torch.cuda.empty_cache()
    return results
Exemple #23
0
class BYOLTrainer():
    """Trainer for BYOL (Bootstrap Your Own Latent) self-supervised learning.

    Wraps ``encoder`` with projector/predictor heads (see ``build_model``)
    and trains the online branch with a cosine loss against a target branch
    whose weights track the online ones by exponential moving average.
    Supports single-process and DDP training, linear LR warmup followed by
    cosine/polynomial decay, target-momentum scheduling, checkpointing and
    TensorBoard logging (rank 0 only).
    """

    def __init__(self,
                 encoder, representation_size, projection_size, projection_hidden_size,
                 train_dataloader, prepare_views, total_epochs, warmup_epochs, base_lr, base_momentum,
                 batch_size=256, decay='cosine', n_decay=1.5, m_decay='cosine',
                 optimizer_type="lars", momentum=1.0, weight_decay=1.0, exclude_bias_and_bn=False,
                 transform=None, transform_1=None, transform_2=None, symmetric_loss=False,
                 world_size=1, rank=0, distributed=False, gpu=0, master_gpu=0, port='12355',
                 ckpt_path="./models/ckpt-%d.pt", log_step=1, log_dir=None, **kwargs):
        """Build the BYOL network, optimizer and (optionally) the DDP group.

        Args:
            encoder: backbone producing ``representation_size`` features.
            representation_size: output dim of the encoder.
            projection_size: output dim of projector and predictor heads.
            projection_hidden_size: hidden dim of the MLP heads.
            train_dataloader: yields raw inputs consumed by ``prepare_views``.
            prepare_views: callable mapping a batch to {'view1', 'view2'}.
            total_epochs / warmup_epochs: schedule lengths in epochs.
            base_lr: LR per 256 samples; scaled by the global batch size.
            base_momentum: initial EMA momentum for the target network.
            decay: LR decay after warmup, 'cosine' or 'poly'.
            n_decay: exponent for polynomial decay.
            m_decay: momentum schedule, 'cosine' or 'cste' (constant).
            optimizer_type: 'lars', 'sgd' or 'adam'.
            exclude_bias_and_bn: exclude bias/BN params from weight decay
                and LARS adaptation.
            transform, transform_1, transform_2: optional GPU-side
                augmentations (``transform`` is the shared fallback).
            symmetric_loss: if True, average the loss over both view orders.
            world_size, rank, distributed, gpu, master_gpu, port: DDP setup.
            ckpt_path: printf-style path pattern taking the epoch number.
            log_step: log every ``log_step`` optimization steps.
            log_dir: TensorBoard directory (rank 0 only).
            **kwargs: ignored; absorbs unused config entries.
        """
        # device parameters
        self.world_size = world_size
        self.rank = rank
        self.gpu = gpu
        self.master_gpu = master_gpu
        self.distributed = distributed

        if torch.cuda.is_available():
            self.device = torch.device(f'cuda:{self.gpu}')
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device('cpu')

        print('Using %r.' % self.device)

        # checkpoint path pattern (formatted with the epoch number)
        self.ckpt_path = ckpt_path

        # build network
        self.representation_size = representation_size
        self.projection_size = projection_size
        self.projection_hidden_size = projection_hidden_size
        self.model = self.build_model(encoder)

        if self.distributed:
            os.environ['MASTER_ADDR'] = 'localhost'
            os.environ['MASTER_PORT'] = port
            dist.init_process_group(backend='nccl', init_method='env://', rank=self.rank, world_size=self.world_size)
            self.group = dist.new_group()

            # BN statistics must be synchronized across ranks for BYOL
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            self.model = DDP(self.model, device_ids=[self.gpu], find_unused_parameters=True)

        # dataloaders
        self.train_dataloader = train_dataloader
        self.prepare_views = prepare_views  # outputs view1 and view2 (pre-gpu-transform)

        # GPU-side transforms; CPU transforms can live in the dataloader.
        # Either both views get a transform or neither does.
        self.transform_1 = transform_1 if transform_1 is not None else transform  # class 1 of transformations
        self.transform_2 = transform_2 if transform_2 is not None else transform  # class 2 of transformations
        assert (self.transform_1 is None) == (self.transform_2 is None)

        # training parameters
        self.total_epochs = total_epochs
        self.warmup_epochs = warmup_epochs

        # todo fix batch shape (double batch loader)
        self.train_batch_size = batch_size
        self.global_batch_size = self.world_size * self.train_batch_size

        # schedules are expressed in optimization steps, not epochs
        self.num_examples = len(self.train_dataloader.dataset)
        self.warmup_steps = self.warmup_epochs * self.num_examples // self.global_batch_size
        self.total_steps = self.total_epochs * self.num_examples // self.global_batch_size

        self.step = 0
        # linear-scaling rule: base_lr is per 256 samples
        base_lr = base_lr / 256
        self.max_lr = base_lr * self.global_batch_size

        self.base_mm = base_momentum

        assert decay in ['cosine', 'poly']
        self.decay = decay
        self.n_decay = n_decay

        assert m_decay in ['cosine', 'cste']
        self.m_decay = m_decay

        # configure optimizer
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.exclude_bias_and_bn = exclude_bias_and_bn

        if self.exclude_bias_and_bn:
            if not self.distributed:
                params = self._collect_params(self.model.trainable_modules)
            else:
                # todo make sure this is correct
                params = self._collect_params(self.model.module.trainable_module_list)
        else:
            params = self.model.parameters()

        if optimizer_type == "lars":
            self.optimizer = LARS(params, lr=self.max_lr, momentum=self.momentum, weight_decay=self.weight_decay)
        elif optimizer_type == "sgd":
            self.optimizer = SGD(params, lr=base_lr, momentum=self.momentum, weight_decay=self.weight_decay)
        elif optimizer_type == "adam":
            if momentum != 1.0:
                warnings.warn("Adam optimizer doesn't use momentum. Momentum %.2f will be ignored." % momentum)
            self.optimizer = Adam(params, lr=base_lr, weight_decay=self.weight_decay)
        else:
            raise ValueError("Optimizer type needs to be 'lars', 'sgd' or 'adam', got (%s)." % optimizer_type)

        self.loss = CosineLoss().to(self.device)
        self.symmetric_loss = symmetric_loss

        # logging (only rank 0 owns a SummaryWriter)
        self.log_step = log_step
        if self.rank == 0:
            self.writer = SummaryWriter(log_dir)

    def build_model(self, encoder):
        """Assemble encoder + projector + predictor into a BYOL net on device."""
        projector = MLP3(self.representation_size, self.projection_size, self.projection_hidden_size)
        predictor = MLP3(self.projection_size, self.projection_size, self.projection_hidden_size)
        net = BYOL(encoder, projector, predictor)
        return net.to(self.device)

    def _collect_params(self, model_list):
        """Build per-parameter groups excluding bias/BN from decay and LARS.

        In the PyTorch implementation of ResNet, `downsample.1` are bn layers.
        """
        param_list = []
        for model in model_list:
            for name, param in model.named_parameters():
                if self.exclude_bias_and_bn and ('bn' in name or 'downsample.1' in name or 'bias' in name):
                    param_dict = {'params': param, 'weight_decay': 0., 'lars_exclude': True}
                else:
                    param_dict = {'params': param}
                param_list.append(param_dict)
        return param_list

    def _cosine_decay(self, step):
        """Cosine-anneal the LR from max_lr to 0 over the post-warmup steps."""
        return 0.5 * self.max_lr * (1 + np.cos((step - self.warmup_steps) * np.pi / (self.total_steps - self.warmup_steps)))

    def _poly_decay(self, step):
        """Polynomial (power ``n_decay``) LR decay over the post-warmup steps."""
        return self.max_lr * (1 - ((step - self.warmup_steps) / (self.total_steps - self.warmup_steps)) ** self.n_decay)

    def update_learning_rate(self, step, decay='poly'):
        """Learning rate warm up and decay.

        Note: the ``decay`` argument is kept for backward compatibility but
        is ignored; the schedule chosen at construction (``self.decay``) is
        always used.
        """
        if step <= self.warmup_steps:
            # linear warmup from 0 to max_lr
            lr = self.max_lr * step / self.warmup_steps
        elif self.decay == 'cosine':
            lr = self._cosine_decay(step)
        elif self.decay == 'poly':
            lr = self._poly_decay(step)
        else:
            raise AttributeError
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def update_momentum(self, step):
        """Set the target-network EMA momentum for this step (into self.mm)."""
        if self.m_decay == 'cosine':
            # anneal from base_mm towards 1.0 over the full schedule
            self.mm = 1 - (1 - self.base_mm) * (np.cos(np.pi * step / self.total_steps) + 1) / 2
        elif self.m_decay == 'cste':
            self.mm = self.base_mm
        else:
            raise AttributeError

    def save_checkpoint(self, epoch):
        """Write model/optimizer/step state to ``ckpt_path % epoch`` (rank 0)."""
        if self.rank == 0:
            state = {
                     'epoch': epoch,
                     'steps': self.step,
                     'model': self.model.state_dict(),
                     'optimizer': self.optimizer.state_dict(),
                    }
            torch.save(state, self.ckpt_path % (epoch))

    def load_checkpoint(self, epoch):
        """Restore model/optimizer/step state saved for ``epoch``.

        Saved tensors are remapped onto this rank's GPU so checkpoints
        written on another device load correctly.
        """
        model_path = self.ckpt_path % (epoch)
        # Fixed: a previous version first built a {src: dst} remap dict whose
        # key was never formatted, then immediately overwrote it — the plain
        # device string below is the intended behavior.
        map_location = "cuda:{}".format(self.gpu)
        checkpoint = torch.load(model_path, map_location=map_location)

        self.step = checkpoint['steps']
        # strict=False: tolerate key differences (e.g. DDP 'module.' prefixes)
        self.model.load_state_dict(checkpoint['model'], strict=False)

        self.optimizer.load_state_dict(checkpoint['optimizer'])

    def cleanup(self):
        """Tear down the distributed process group."""
        dist.destroy_process_group()

    def forward_loss(self, preds, targets):
        """Compute the cosine loss between online predictions and targets."""
        loss = self.loss(preds, targets)
        return loss

    def update_target_network(self):
        """EMA-update the target network with the current momentum self.mm."""
        if not self.distributed:
            self.model.update_target_network(self.mm)
        else:
            self.model.module.update_target_network(self.mm)

    def log_schedule(self, loss):
        """Log lr/momentum/loss to TensorBoard (requires rank-0 writer)."""
        self.writer.add_scalar('lr', self.optimizer.param_groups[0]['lr'], self.step)
        self.writer.add_scalar('mm', self.mm, self.step)
        self.writer.add_scalar('loss', loss, self.step)

    def train_epoch(self):
        """Run one full pass over the training dataloader."""
        self.model.train()
        for inputs in self.train_dataloader:
            # refresh schedules for the current optimization step
            self.update_learning_rate(self.step)
            self.update_momentum(self.step)

            inputs = self.prepare_views(inputs)
            view1 = inputs['view1'].to(self.device)
            view2 = inputs['view2'].to(self.device)

            if self.transform_1 is not None:
                # apply GPU-side augmentations
                view1 = self.transform_1(view1)
                view2 = self.transform_2(view2)

            # forward
            outputs = self.model({'online_view': view1, 'target_view': view2})
            loss = self.forward_loss(outputs['online_q'], outputs['target_z'])
            if self.symmetric_loss:
                # swap the views and average the two directions
                outputs = self.model({'online_view': view2, 'target_view': view1})
                loss += self.forward_loss(outputs['online_q'], outputs['target_z'])
                loss /= 2

            # backprop online network
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # update moving average of the target network
            self.update_target_network()

            # log
            if self.step % self.log_step == 0 and self.rank == 0:
                self.log_schedule(loss=loss.item())

            self.step += 1
Exemple #24
0
    # text_generator = TextGenerator(text_encoder.embed.num_embeddings).cuda()
    # score_model = ScoreModel(30522, 256, num_heads=1).cuda()
    # category_embedding = CategoryEmbedding(256).cuda()

    optimizer = Adam(image_encoder.get_params() + text_encoder.get_params() +
                     score_model.get_params())

    if start_epoch > 0 and local_rank == 0:
        checkpoints = torch.load(
            os.path.join(checkpoints_dir,
                         'model-epoch{}.pth'.format(start_epoch)), 'cpu')
        text_encoder.load_state_dict(checkpoints['query'])
        image_encoder.load_state_dict(checkpoints['item'])
        score_model.load_state_dict(checkpoints['score'])
        # text_generator.load_state_dict(checkpoints['generator'])
        optimizer.load_state_dict(checkpoints['optimizer'])
        print("load checkpoints")
    # generator = iterate_minibatches(iters=(30 - start_epoch) * len(loader), batch_size=256, num_workers=8, root_dir='/home/dingyuhui/dataset/kdd-data', use_bert=use_bert)

    scheduler = ExponentialLR(optimizer, 0.95, last_epoch=start_epoch - 1)
    text_encoder = nn.parallel.DistributedDataParallel(
        text_encoder,
        find_unused_parameters=True,
        device_ids=[local_rank],
        output_device=local_rank)
    image_encoder = nn.parallel.DistributedDataParallel(
        image_encoder,
        find_unused_parameters=True,
        device_ids=[local_rank],
        output_device=local_rank)
    score_model = nn.parallel.DistributedDataParallel(
Exemple #25
0
def test(config, epoch_to_load, lr=1e-4, batch_size=64):
    """Evaluate a saved WSD checkpoint on the configured test sets.

    Restores the model/optimizer from epoch ``epoch_to_load``, runs batched
    candidate-masked inference over each dataset in ``config.tests``, writes
    per-token predictions (gold labels and score distributions) to one TSV
    per dataset, and prints per-dataset accuracy plus an overall score that
    excludes the dev set (``config.dev_name``).

    Args:
        config: experiment configuration (paths, inventory, model name, flags).
        epoch_to_load (int): epoch whose checkpoint is loaded.
        lr (float): LR for the rebuilt optimizer (overwritten by checkpoint).
        batch_size (int): sentences per inference batch.

    Raises:
        ValueError: if ``config.model_name`` is not a known model.
    """
    inventory = config.inventory
    gold_folder = os.path.join(config.data_folder,
                               'gold/{}/'.format(inventory))

    test_list = config.tests
    exp_folder = config.experiment_folder
    dev_name = config.dev_name
    text_input_folder = os.path.join(config.data_folder,
                                     'input/text_files/{}/'.format(inventory))
    input_folder = os.path.join(config.data_folder,
                                'input/matrices/{}/'.format(inventory))

    mapping = pkl.load(open(config.mapping_path, 'rb'))
    # fine-grained setups predict sense keys instead of coarse domains
    domains_vocab_path = os.path.join(text_input_folder, 'domains.pkl')
    if config.finegrained:
        domains_vocab_path = os.path.join(text_input_folder, 'sensekeys.pkl')

    domains_vocab = pkl.load(open(domains_vocab_path, 'rb'))
    results_folder = os.path.join(exp_folder, 'results/')

    # label index 0 is reserved for "no label" (None)
    labels = sorted([x for x in domains_vocab if x != 'untagged'])
    labels_dict = {label: k + 1 for k, label in enumerate(labels)}
    labels_dict[None] = 0
    reverse_labels_dict = {v: k for k, v in labels_dict.items()}

    if config.model_name == 'BertLSTM':
        model = BertLSTM(len(labels_dict))
    elif config.model_name == 'BertDense':
        model = BertDense(len(labels_dict), 'bert-large-cased')
    else:
        # fail fast instead of a NameError at the optimizer below
        raise ValueError('Unknown model name: {}'.format(config.model_name))

    optimizer = Adam(model.parameters(), lr=lr)
    path_checkpoints = os.path.join(config.experiment_folder, 'weights',
                                    'checkpoint_{}.tar'.format(epoch_to_load))
    checkpoint = torch.load(path_checkpoints)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.eval()

    cor, tot = 0, 0
    wrong_labels_distr = {}  # gold label -> {wrongly predicted label: count}
    for name in test_list:
        x = pkl.load(
            open(os.path.join(input_folder, '{}_words.pkl'.format(name)),
                 'rb')).tolist()
        y = pkl.load(
            open(os.path.join(input_folder, '{}_domains.pkl'.format(name)),
                 'rb')).tolist()

        if config.finegrained:
            y = pkl.load(
                open(
                    os.path.join(input_folder,
                                 '{}_sensekeys.pkl'.format(name)),
                    'rb')).tolist()

        y_idx = process_labels(y, labels_dict).cpu().numpy().tolist()
        tokens = process_tokens(
            pkl.load(
                open(os.path.join(input_folder, '{}_tokens.pkl'.format(name)),
                     'rb')).tolist())

        file_txt = os.path.join(text_input_folder, '{}_input.txt'.format(name))
        token2lemma = {
            line.strip().split()[0]: line.strip().split()[1]
            for line in open(file_txt).readlines()
        }
        output_file = os.path.join(results_folder,
                                   '{}_output.tsv'.format(name))
        candidate_domains = utils.build_possible_senses(
            labels_dict, os.path.join(text_input_folder, 'semcor_input.txt'))

        gold_dictionary = {
            line.strip().split()[0]: line.strip().split()[1:]
            for line in open(
                os.path.join(gold_folder, '{}.gold.txt'.format(name)))
        }

        c, t = 0, 0
        with open(output_file, 'w') as fw:
            for i in tqdm.tqdm(range(0, len(x), batch_size)):
                inputs = x[i:i + batch_size]
                labels_idx = y_idx[i:i + batch_size]
                token_batch = tokens[i:i + batch_size]
                # restrict scores to each word's candidate senses
                mask = utils.build_mask(words=inputs,
                                        true_y=labels_idx,
                                        labels_dict=labels_dict,
                                        tokens=token_batch,
                                        file_txt=file_txt,
                                        candidate=candidate_domains)

                # model.eval() was already called above; calling the model
                # directly avoids the redundant per-batch eval() call
                eval_out = torch.exp(model(inputs))
                eval_out *= mask.cuda()
                _, predicted = torch.max(eval_out, 2)
                for id_sent, token_sent in enumerate(token_batch):
                    for id_word, token in enumerate(token_sent):
                        if token is None or token == 'untagged':
                            continue
                        gold_label = gold_dictionary[token]
                        pred_label = reverse_labels_dict[predicted[
                            id_sent, id_word].item()]
                        if pred_label is None:
                            # back off to the most frequent sense
                            pred_label = utils.getMFS(
                                token, mapping, file_txt)
                            scored_labels = None
                        else:
                            # non-zero (unmasked) candidates, best first
                            scores_idx = torch.nonzero(eval_out[id_sent,
                                                                id_word])
                            scored_labels = sorted(
                                [(reverse_labels_dict[idx.item()],
                                  eval_out[id_sent, id_word, idx].item())
                                 for idx in scores_idx],
                                key=lambda s: s[1],
                                reverse=True)
                        if pred_label in gold_label:
                            c += 1
                            fw.write('c\t')
                        else:
                            if scored_labels is not None:
                                # record the confusion: gold -> predicted
                                for gl in gold_label:
                                    bucket = wrong_labels_distr.setdefault(
                                        gl, {})
                                    bucket[pred_label] = bucket.get(
                                        pred_label, 0) + 1
                            fw.write('w\t')
                        t += 1
                        fw.write(name + '.' + token + '\t' +
                                 token2lemma[token] + '\t' + 'pred##' +
                                 pred_label + '\t')
                        fw.write(
                            ' '.join(['gold##' + gl
                                      for gl in gold_label]) + '\t')
                        if scored_labels is not None:
                            fw.write(' '.join([
                                label + '##' + str(score)
                                for label, score in scored_labels
                            ]) + '\t')
                            # Fixed: index the current batch (inputs), not
                            # the whole dataset (the original used x[id_sent],
                            # which printed the wrong sentence after batch 0)
                            content_words = [
                                word if index != id_word else
                                '<tag>{}</tag>'.format(word)
                                for index, word in enumerate(inputs[id_sent])
                                if word != 'PADDING'
                            ]
                            fw.write(' '.join(content_words))
                        fw.write('\n')
        # the dev set is excluded from the overall score
        if name != dev_name:
            cor += c
            tot += t
        f1 = np.round(c / t, 3)
        print(name, f1)
    print('ALL', np.round(cor / tot, 3))
Exemple #26
0
def main(args):
    """Train a RoboSat-pink segmentation model from a config file.

    Loads the run configuration, builds the network/optimizer/loss named in
    the config, optionally restores a checkpoint (local path or s3:// URI),
    then runs the train/validate loop, saving a checkpoint every epoch.

    Args:
        args: parsed CLI namespace with ``config``, ``out``, ``workers`` and
            optional ``checkpoint`` attributes.
    """
    config = load_config(args.config)
    print(config)

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        device = torch.device("cuda")

        torch.backends.cudnn.benchmark = True
        log.log("RoboSat - training on {} GPUs, with {} workers".format(
            torch.cuda.device_count(), args.workers))
    else:
        device = torch.device("cpu")
        log.log("RoboSat - training on CPU, with {} workers".format(
            args.workers))

    num_classes = len(config["classes"])
    # one input plane per band over all configured channels
    num_channels = sum(len(channel["bands"]) for channel in config["channels"])
    pretrained = config["model"]["pretrained"]
    encoder = config["model"]["encoder"]

    # model classes are discovered from the robosat_pink.models package
    models = [
        name for _, name, _ in pkgutil.iter_modules(
            [os.path.dirname(robosat_pink.models.__file__)])
    ]
    if config["model"]["name"] not in models:
        sys.exit("Unknown model, those available are {}".format(models))

    model_module = import_module("robosat_pink.models.{}".format(
        config["model"]["name"]))
    net = getattr(model_module, "{}".format(config["model"]["name"].title()))(
        num_classes=num_classes,
        num_channels=num_channels,
        encoder=encoder,
        pretrained=pretrained).to(device)

    net = torch.nn.DataParallel(net)
    optimizer = Adam(net.parameters(),
                     lr=config["model"]["lr"],
                     weight_decay=config["model"]["decay"])

    resume = 0

    # Checkpoint resolution: a config-file entry overrides the CLI flag.
    checkpoint = None  # no checkpoint
    if args.checkpoint:  # command line checkpoint
        checkpoint = args.checkpoint
    try:  # config file checkpoint
        checkpoint = config["checkpoint"]['path']
    except (KeyError, TypeError):
        # no checkpoint entry in the config file
        pass

    def map_location(storage, _):
        # keep tensors on GPU when available, else fall back to CPU
        return storage.cuda() if torch.cuda.is_available() else storage.cpu()

    S3_CHECKPOINT = False
    if checkpoint:

        if checkpoint.startswith("s3://"):
            S3_CHECKPOINT = True
            # stream the checkpoint from S3 using the configured AWS profile
            checkpoint = checkpoint[5:]
            sess = boto3.Session(profile_name=config['dataset']['aws_profile'])
            fs = s3fs.S3FileSystem(session=sess)
            s3ckpt = s3fs.S3File(fs, checkpoint, 'rb')

        try:
            if S3_CHECKPOINT:
                with s3fs.S3File(fs, checkpoint, 'rb') as C:
                    state = torch.load(io.BytesIO(C.read()),
                                       map_location=map_location)
            else:
                # map_location keeps CPU-only machines working on GPU saves
                state = torch.load(checkpoint, map_location=map_location)
            optimizer.load_state_dict(state['optimizer'])
            net.load_state_dict(state['state_dict'])
            net.to(device)
        except FileNotFoundError:
            # fixed: original referenced an undefined name CHECKPOINT here
            print("{} checkpoint not found.".format(checkpoint))

        log.log("Using checkpoint: {}".format(checkpoint))

    # loss classes are discovered from the robosat_pink.losses package
    losses = [
        name for _, name, _ in pkgutil.iter_modules(
            [os.path.dirname(robosat_pink.losses.__file__)])
    ]
    if config["model"]["loss"] not in losses:
        sys.exit("Unknown loss, those available are {}".format(losses))

    loss_module = import_module("robosat_pink.losses.{}".format(
        config["model"]["loss"]))
    criterion = getattr(loss_module, "{}".format(
        config["model"]["loss"].title()))().to(device)

    train_loader, val_loader = get_dataset_loaders(config,
                                                   args.workers,
                                                   idDir=args.out)

    if resume >= config["model"]["epochs"]:
        sys.exit(
            "Error: Epoch {} set in {} already reached by the checkpoint provided"
            .format(config["model"]["epochs"], args.config))

    log.log("")
    log.log("--- Input tensor from Dataset: {} ---".format(
        config["dataset"]["image_bucket"] + '/' +
        config['dataset']['imagery_directory_regex']))

    log.log("")
    log.log("--- Hyper Parameters ---")
    log.log("Model:\t\t\t {}".format(config["model"]["name"]))
    log.log("Encoder model:\t\t {}".format(config["model"]["encoder"]))
    log.log("Loss function:\t\t {}".format(config["model"]["loss"]))
    log.log("ResNet pre-trained:\t {}".format(config["model"]["pretrained"]))
    log.log("Batch Size:\t\t {}".format(config["model"]["batch_size"]))
    log.log("Tile Size:\t\t {}".format(config["model"]["tile_size"]))
    log.log("Data Augmentation:\t {}".format(
        config["model"]["data_augmentation"]))
    log.log("Learning Rate:\t\t {}".format(config["model"]["lr"]))
    log.log("Weight Decay:\t\t {}".format(config["model"]["decay"]))
    log.log("")

    for epoch in range(resume, config["model"]["epochs"]):

        log.log("---")
        log.log("Epoch: {}/{}".format(epoch + 1, config["model"]["epochs"]))

        train_hist = train(train_loader, num_classes, device, net, optimizer,
                           criterion)
        log.log(
            "Train    loss: {:.4f}, mIoU: {:.3f}, IoU: {:.3f}, precision:  {:.3f}, recall: {:.3f}"
            .format(
                train_hist["loss"],
                train_hist["miou"],
                train_hist["fg_iou"],
                train_hist["precision"],
                train_hist["recall"],
            ))

        val_hist = validate(val_loader, num_classes, device, net, criterion)
        # fixed: original logged train_hist metrics under the Validate line
        log.log(
            "Validate loss: {:.4f}, mIoU: {:.3f}, IoU: {:.3f}, precision:  {:.3f}, recall: {:.3f}"
            .format(
                val_hist["loss"],
                val_hist["miou"],
                val_hist["fg_iou"],
                val_hist["precision"],
                val_hist["recall"],
            ))

        states = {
            "epoch": epoch + 1,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict()
        }
        checkpoint_path = os.path.join(
            args.out, "checkpoint-{:05d}-of-{:05d}.pth".format(
                epoch + 1, config["model"]["epochs"]))
        torch.save(states, checkpoint_path)
Exemple #27
0
def test_few_shot(config, epoch_to_load, k, lr=1e-4):
    """Evaluate a saved k-shot WSD checkpoint on the configured test corpora.

    Restores the model/optimizer states saved at epoch ``epoch_to_load`` for
    the ``k``-shot setting, predicts a domain label for every gold-tagged
    token of each corpus in ``config.tests``, writes one TSV of per-instance
    predictions per corpus, and prints per-corpus accuracy plus an overall
    accuracy that excludes the dev corpus (``config.dev_name``).

    Args:
        config: experiment configuration object (inventory, folders, model
            name, list of test corpora, dev corpus name).
        epoch_to_load: epoch index of the checkpoint to restore.
        k: number of shots; selects the input/weights/results sub-folders.
        lr: learning rate for the rebuilt Adam optimizer; irrelevant for
            evaluation, but an optimizer is needed to restore its state dict.

    NOTE(review): requires CUDA (``.cuda()`` calls below) and assumes
    ``results_folder`` already exists — neither is checked here.
    """
    inventory = config.inventory
    data_folder = config.data_folder
    exp_folder = config.experiment_folder
    test_list = config.tests
    dev_name = config.dev_name

    # Layout: gold labels, raw text inputs, pre-extracted matrices, and the
    # k-shot SemCor training file used to build the candidate-label mask.
    gold_folder = os.path.join(data_folder, 'gold/{}/'.format(inventory))
    text_input_folder = os.path.join(data_folder,
                                     'input/text_files/{}/'.format(inventory))
    input_folder = os.path.join(data_folder,
                                'input/matrices/{}/'.format(inventory))
    input_semcor_k = os.path.join(text_input_folder,
                                  'semcor_input_{}.txt'.format(k))

    domains_vocab_path = os.path.join(config.all_words_folder, 'domains.pkl')
    domains_vocab = pkl.load(open(domains_vocab_path, 'rb'))
    results_folder = os.path.join(exp_folder, 'results/{}/'.format(k))

    # Label index 0 is reserved for "no label" (None); real domains start at 1.
    # The comprehension variable ``k`` shadows the parameter only inside the
    # comprehension (Python 3 scoping) — the few-shot ``k`` is unaffected.
    labels = sorted([x for x in domains_vocab if x != 'untagged'])
    labels_dict = {label: k + 1 for k, label in enumerate(labels)}
    labels_dict[None] = 0
    reverse_labels_dict = {v: k for k, v in labels_dict.items()}

    # NOTE(review): ``model`` is unbound (NameError below) if
    # config.model_name is neither of these two — confirm upstream validation.
    if config.model_name == 'BertLSTM':
        model = BertLSTM(len(labels_dict))

    elif config.model_name == 'BertDense':
        model = BertDense(len(labels_dict))

    optimizer = Adam(model.parameters(), lr=lr)

    # Restore model + optimizer from the k-shot checkpoint and freeze for eval.
    path_checkpoints = os.path.join(exp_folder, 'weights', '{}'.format(k),
                                    'checkpoint_{}.tar'.format(epoch_to_load))
    checkpoint = torch.load(path_checkpoints)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.eval()

    # cor/tot accumulate correct/total counts over all non-dev corpora.
    cor, tot = 0, 0
    for name in test_list:
        # Per-corpus pre-extracted inputs: word matrices, domain labels,
        # token-id matrices, and the token -> lemma text file.
        x = pkl.load(
            open(os.path.join(input_folder, '{}_words.pkl'.format(name)),
                 'rb')).tolist()
        y = pkl.load(
            open(os.path.join(input_folder, '{}_domains.pkl'.format(name)),
                 'rb')).tolist()
        file_txt = os.path.join(text_input_folder, '{}_input.txt'.format(name))

        tokens = process_tokens(
            pkl.load(
                open(os.path.join(input_folder, '{}_tokens.pkl'.format(name)),
                     'rb')).tolist())
        token2lemma = {
            line.strip().split()[0]: line.strip().split()[1]
            for line in open(file_txt).readlines()
        }
        output_file = os.path.join(results_folder,
                                   '{}_output.tsv'.format(name))
        fw = open(output_file, 'w')

        # Mask restricting predictions to labels observed in the k-shot
        # training data — presumably zeroing out unseen classes; confirm
        # against utils.build_mask_from_training_k.
        mask = utils.build_mask_from_training_k(tokens, file_txt, labels_dict,
                                                input_semcor_k)
        # instance id -> list of acceptable gold labels (multiple allowed).
        gold_dictionary = {
            line.strip().split()[0]: line.strip().split()[1:]
            for line in open(
                os.path.join(gold_folder, '{}.gold.txt'.format(name)))
        }

        # torch.exp suggests the model outputs log-probabilities — TODO
        # confirm; masked scores of shape (sent, tok, label) are argmax'd
        # over the label axis.
        eval_out = torch.exp(model.eval()(x))
        eval_out *= mask.cuda()
        predicted = torch.argmax(eval_out, 2)
        # Only positions with a non-zero (i.e. tagged) gold label are scored.
        to_pred_idx = torch.nonzero(process_labels(y, labels_dict)).cuda()
        c, t = 0, 0
        for a, b in to_pred_idx:
            instance_id = tokens[a][b]
            gold_label = gold_dictionary[instance_id]
            pred_label = reverse_labels_dict[predicted[a, b].item()]
            # TSV row starts with 'c' (correct) or 'w' (wrong).
            if pred_label in gold_label:
                c += 1
                fw.write('c\t')
            else:
                fw.write('w\t')
            t += 1
            # All non-zero-scoring labels, sorted by descending score, are
            # appended for later analysis.
            scores_idx = torch.nonzero(eval_out[a, b])
            scored_ = [(reverse_labels_dict[idx.item()], eval_out[a, b,
                                                                  idx].item())
                       for idx in scores_idx]
            scored_labels = sorted(scored_, key=lambda x: x[1], reverse=True)
            fw.write(name + '.' + instance_id + '\t' +
                     token2lemma[instance_id] + '\t' + 'pred##' + pred_label +
                     '\t')
            fw.write(' '.join(['gold##' + gl for gl in gold_label]) + '\t')
            fw.write(' '.join(
                [label + '##' + str(score)
                 for label, score in scored_labels]) + '\n')

        # The dev corpus is reported but excluded from the overall score.
        if name != dev_name:
            cor += c
            tot += t

        # NOTE(review): named ``f1`` but computed as plain accuracy (c/t).
        f1 = np.round(c / t, 3)
        print(name, f1)
    print('ALL', np.round(cor / tot, 3))
Exemple #28
0
class Training(Base):
    '''
    Training the searched network
    cf: config.yml path
    cv_i: Which fold in the cross validation. If cv_i >= n_fold: use all the training dataset.
    for_train: If True, for training process, otherwise for searching.
    new_lr: if True, check_resume() will not load the saved states of optimizers and lr_schedulers.
    '''
    def __init__(self, cf='config.yml', cv_i=0, for_train=True, new_lr=False):
        # Base is expected to set up config, device, log_path and the
        # train/val generators used below — TODO confirm against Base.
        super().__init__(cf=cf, cv_i=cv_i, for_train=for_train)
        self._init_model()
        self.check_resume(new_lr=new_lr)

    def _init_model(self):
        """Build the searched network, loss, optimizer and LR scheduler."""
        geno_file = os.path.join(self.log_path,
                                 self.config['search']['geno_file'])
        # NOTE(review): eval() on unpickled text — safe only if the genotype
        # file is trusted (it is produced by this project's search phase).
        with open(geno_file, 'rb') as f:
            gene = eval(pickle.load(f)[0])
        self.model = SearchedNet(
            gene=gene,
            in_channels=self.config['data']['in_channels'],
            init_node_c=self.config['search']['init_node_c'],
            out_channels=self.config['data']['out_channels'],
            depth=self.config['search']['depth'],
            n_nodes=self.config['search']['n_nodes'],
            drop_rate=self.config['train']['drop_rate']).to(self.device)
        print('Param size = {:.3f} MB'.format(
            calc_param_size(self.model.parameters())))
        self.loss = nn.CrossEntropyLoss().to(self.device)

        self.optim = Adam(self.model.parameters())
        # Halve the LR when the monitored metric (val loss) plateaus.
        self.scheduler = ReduceLROnPlateau(self.optim,
                                           verbose=True,
                                           factor=0.5)

    def check_resume(self, new_lr=False):
        """Resume from the last checkpoint if one exists.

        Restores epoch counter, history, model weights, best val loss and —
        unless ``new_lr`` is True — the optimizer/scheduler states, so a
        resumed run keeps (or resets) its learning-rate schedule.
        """
        self.last_save = os.path.join(self.log_path,
                                      self.config['train']['last_save'])
        self.best_shot = os.path.join(self.log_path,
                                      self.config['train']['best_shot'])
        if os.path.exists(self.last_save):
            state_dicts = torch.load(self.last_save, map_location=self.device)
            # Saved epoch is the last completed one; resume at the next.
            self.epoch = state_dicts['epoch'] + 1
            self.history = state_dicts['history']
            self.model.load_state_dict(state_dicts['model_param'])
            if not new_lr:
                self.optim.load_state_dict(state_dicts['optim'])
                self.scheduler.load_state_dict(state_dicts['scheduler'])
            self.best_val_loss = state_dicts['best_loss']
        else:
            # Fresh run: empty history, no best loss yet.
            self.epoch = 0
            self.history = defaultdict(list)
            self.best_val_loss = float('inf')

    def main_run(self):
        """Main train/validate loop with checkpointing and best-model copy."""
        n_epochs = self.config['train']['epochs']

        # NOTE(review): the loop counts n_epochs iterations from 0 while
        # self.epoch may start > 0 after a resume; the `self.epoch > n_epochs`
        # break below caps the total — confirm this is the intended budget.
        for epoch in range(n_epochs):
            is_best = False
            loss, acc1, acc5 = self.train()
            val_loss, val_acc1, val_acc5 = self.validate()
            # ReduceLROnPlateau steps on the monitored validation loss.
            self.scheduler.step(val_loss)
            self.history['loss'].append(loss)
            self.history['acc1'].append(acc1)
            self.history['acc5'].append(acc5)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc1'].append(val_acc1)
            self.history['val_acc5'].append(val_acc5)
            if val_loss < self.best_val_loss:
                is_best = True
                self.best_val_loss = val_loss

            # Save what the current epoch ends up with.
            state_dicts = {
                'epoch': self.epoch,
                'history': self.history,
                'model_param': self.model.state_dict(),
                'optim': self.optim.state_dict(),
                'scheduler': self.scheduler.state_dict(),
                'best_loss': self.best_val_loss
            }
            torch.save(state_dicts, self.last_save)

            # The best checkpoint is a byte copy of the last save.
            if is_best:
                shutil.copy(self.last_save, self.best_shot)

            self.epoch += 1
            if self.epoch > n_epochs:
                break

            # Debug mode: stop after two epochs.
            if DEBUG_FLAG and epoch >= 1:
                break
        print('Training Finished.')
        return

    def train(self):
        '''
        Training | Training process

        Returns [mean_loss, mean_top1, mean_top5] over the epoch, each
        rounded to 3 decimals.
        '''
        self.model.train()
        n_steps = self.train_generator.steps_per_epoch
        sum_loss = 0
        sum_acc1 = 0
        sum_acc5 = 0
        with tqdm(self.train_generator.epoch(),
                  total=n_steps,
                  desc='Training | Epoch {} | Training'.format(
                      self.epoch)) as pbar:
            for step, (x, y_truth) in enumerate(pbar):
                x = torch.as_tensor(x, device=self.device, dtype=torch.float)
                y_truth = torch.as_tensor(y_truth,
                                          device=self.device,
                                          dtype=torch.long)

                self.optim.zero_grad()
                y_pred = self.model(x)
                loss = self.loss(y_pred, y_truth)
                sum_loss += loss.item()
                acc1, acc5 = accuracy(y_pred, y_truth, topk=(1, 5))
                sum_acc1 += acc1
                sum_acc5 += acc5
                loss.backward()
                self.optim.step()

                # Progress bar shows running means over the steps so far.
                postfix = OrderedDict()
                postfix['Loss'] = round(sum_loss / (step + 1), 3)
                postfix['Top-1-Acc'] = round(sum_acc1 / (step + 1), 3)
                postfix['Top-5-Acc'] = round(sum_acc5 / (step + 1), 3)
                pbar.set_postfix(postfix)

                if DEBUG_FLAG and step >= 1:
                    break

        return [round(i / n_steps, 3) for i in [sum_loss, sum_acc1, sum_acc5]]

    def validate(self):
        '''
        Training | Validation process

        Returns [mean_loss, mean_top1, mean_top5] over the validation set,
        each rounded to 3 decimals.

        NOTE(review): no torch.no_grad() here — gradients are still tracked
        during validation; confirm whether that is intentional.
        '''
        self.model.eval()
        n_steps = self.val_generator.steps_per_epoch
        sum_loss = 0
        sum_acc1 = 0
        sum_acc5 = 0
        with tqdm(self.val_generator.epoch(),
                  total=n_steps,
                  desc='Training | Epoch {} | Val'.format(self.epoch)) as pbar:
            for step, (x, y_truth) in enumerate(pbar):
                x = torch.as_tensor(x, device=self.device, dtype=torch.float)
                y_truth = torch.as_tensor(y_truth,
                                          device=self.device,
                                          dtype=torch.long)
                y_pred = self.model(x)
                loss = self.loss(y_pred, y_truth)
                sum_loss += loss.item()
                acc1, acc5 = accuracy(y_pred, y_truth, topk=(1, 5))
                sum_acc1 += acc1
                sum_acc5 += acc5

                postfix = OrderedDict()
                postfix['Loss'] = round(sum_loss / (step + 1), 3)
                postfix['Top-1-Acc'] = round(sum_acc1 / (step + 1), 3)
                postfix['Top-5-Acc'] = round(sum_acc5 / (step + 1), 3)
                pbar.set_postfix(postfix)

                if DEBUG_FLAG and step >= 1:
                    break
        return [round(i / n_steps, 3) for i in [sum_loss, sum_acc1, sum_acc5]]
Exemple #29
0
def train(args, model, enc=False):
    """Train `model` on Cityscapes, validating and checkpointing every epoch.

    Args:
        args: parsed CLI namespace (datadir, savedir, cuda, batch_size,
            num_epochs, num_workers, height, resume, iouTrain, iouVal,
            steps_loss, steps_plot, epochs_save, visualize, port).
        model: network whose forward accepts ``(inputs, only_encode=enc)``.
        enc: if True, train only the encoder branch — different class
            weights, log/checkpoint file names, and forward mode.

    Returns:
        The trained model (convenience for chained encoder/decoder training).

    Fixes vs. previous revision:
        * periodic model save keyed on ``epoch`` (was ``step``, so
          ``epochs_save`` compared against the last validation-step index);
        * the periodic-save message now prints the actual filename;
        * stray ')' removed from the resume message.
    """
    best_acc = 0

    #TODO: calculate weights by processing dataset histogram (now its being set by hand from the torch values)
    #create a loder to run all images and calculate histogram of labels, then create weight array using class balancing

    # Hand-tuned per-class weights for the 2D cross-entropy loss; two sets,
    # one for encoder-only pretraining and one for the full network.
    weight = torch.ones(NUM_CLASSES)
    if (enc):
        weight[0] = 2.3653597831726
        weight[1] = 4.4237880706787
        weight[2] = 2.9691488742828
        weight[3] = 5.3442072868347
        weight[4] = 5.2983593940735
        weight[5] = 5.2275490760803
        weight[6] = 5.4394111633301
        weight[7] = 5.3659925460815
        weight[8] = 3.4170460700989
        weight[9] = 5.2414722442627
        weight[10] = 4.7376127243042
        weight[11] = 5.2286224365234
        weight[12] = 5.455126285553
        weight[13] = 4.3019247055054
        weight[14] = 5.4264230728149
        weight[15] = 5.4331531524658
        weight[16] = 5.433765411377
        weight[17] = 5.4631009101868
        weight[18] = 5.3947434425354
    else:
        weight[0] = 2.8149201869965
        weight[1] = 6.9850029945374
        weight[2] = 3.7890393733978
        weight[3] = 9.9428062438965
        weight[4] = 9.7702074050903
        weight[5] = 9.5110931396484
        weight[6] = 10.311357498169
        weight[7] = 10.026463508606
        weight[8] = 4.6323022842407
        weight[9] = 9.5608062744141
        weight[10] = 7.8698215484619
        weight[11] = 9.5168733596802
        weight[12] = 10.373730659485
        weight[13] = 6.6616044044495
        weight[14] = 10.260489463806
        weight[15] = 10.287888526917
        weight[16] = 10.289801597595
        weight[17] = 10.405355453491
        weight[18] = 10.138095855713

    # Class 19 is the void/ignore class: zero weight excludes it from the loss.
    weight[19] = 0

    assert os.path.exists(
        args.datadir), "Error: datadir (dataset directory) could not be loaded"

    co_transform = MyCoTransform(enc, augment=True, height=args.height)  #1024)
    co_transform_val = MyCoTransform(enc, augment=False,
                                     height=args.height)  #1024)
    dataset_train = cityscapes(args.datadir, co_transform, 'train')
    dataset_val = cityscapes(args.datadir, co_transform_val, 'val')

    loader = DataLoader(dataset_train,
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=True)
    loader_val = DataLoader(dataset_val,
                            num_workers=args.num_workers,
                            batch_size=args.batch_size,
                            shuffle=False)

    if args.cuda:
        weight = weight.cuda()
    criterion = CrossEntropyLoss2d(weight)
    print(type(criterion))

    savedir = f'../save/{args.savedir}'

    if (enc):
        automated_log_path = savedir + "/automated_log_encoder.txt"
        modeltxtpath = savedir + "/model_encoder.txt"
    else:
        automated_log_path = savedir + "/automated_log.txt"
        modeltxtpath = savedir + "/model.txt"

    if (not os.path.exists(automated_log_path)
        ):  #dont add first line if it exists
        with open(automated_log_path, "a") as myfile:
            myfile.write(
                "Epoch\t\tTrain-loss\t\tTest-loss\t\tTrain-IoU\t\tTest-IoU\t\tlearningRate"
            )

    with open(modeltxtpath, "w") as myfile:
        myfile.write(str(model))

    #TODO: reduce memory in first gpu: https://discuss.pytorch.org/t/multi-gpu-training-memory-usage-in-balance/4163/4        #https://github.com/pytorch/pytorch/issues/1893

    #optimizer = Adam(model.parameters(), 5e-4, (0.9, 0.999),  eps=1e-08, weight_decay=2e-4)     ## scheduler 1
    optimizer = Adam(model.parameters(),
                     5e-4, (0.9, 0.999),
                     eps=1e-08,
                     weight_decay=1e-4)  ## scheduler 2

    start_epoch = 1
    if args.resume:
        #Must load weights, optimizer, epoch and best value.
        if enc:
            filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
        else:
            filenameCheckpoint = savedir + '/checkpoint.pth.tar'

        assert os.path.exists(
            filenameCheckpoint
        ), "Error: resume option was used but checkpoint was not found in folder"
        checkpoint = torch.load(filenameCheckpoint)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        best_acc = checkpoint['best_acc']
        print("=> Loaded checkpoint at epoch {}".format(checkpoint['epoch']))

    #scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5) # set up scheduler     ## scheduler 1
    # Polynomial LR decay (power 0.9) over the full epoch budget.
    lambda1 = lambda epoch: pow(
        (1 - ((epoch - 1) / args.num_epochs)), 0.9)  ## scheduler 2
    scheduler = lr_scheduler.LambdaLR(optimizer,
                                      lr_lambda=lambda1)  ## scheduler 2

    if args.visualize and args.steps_plot > 0:
        board = Dashboard(args.port)

    for epoch in range(start_epoch, args.num_epochs + 1):
        print("----- TRAINING - EPOCH", epoch, "-----")

        # NOTE(review): scheduler stepped at epoch start (pre-1.1 PyTorch
        # convention); changing the order would alter the LR schedule.
        scheduler.step(epoch)  ## scheduler 2

        epoch_loss = []
        time_train = []

        doIouTrain = args.iouTrain
        doIouVal = args.iouVal

        if (doIouTrain):
            iouEvalTrain = iouEval(NUM_CLASSES)

        usedLr = 0
        for param_group in optimizer.param_groups:
            print("LEARNING RATE: ", param_group['lr'])
            usedLr = float(param_group['lr'])

        model.train()
        for step, (images, labels) in enumerate(loader):

            start_time = time.time()
            #print (labels.size())
            #print (np.unique(labels.numpy()))
            #print("labels: ", np.unique(labels[0].numpy()))
            #labels = torch.ones(4, 1, 512, 1024).long()
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()

            inputs = Variable(images)
            targets = Variable(labels)
            outputs = model(inputs, only_encode=enc)

            #print("targets", np.unique(targets[:, 0].cpu().data.numpy()))

            optimizer.zero_grad()
            # targets has a singleton channel dim; CrossEntropyLoss2d wants (N, H, W).
            loss = criterion(outputs, targets[:, 0])
            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.item())
            time_train.append(time.time() - start_time)

            if (doIouTrain):
                #start_time_iou = time.time()
                iouEvalTrain.addBatch(
                    outputs.max(1)[1].unsqueeze(1).data, targets.data)
                #print ("Time to add confusion matrix: ", time.time() - start_time_iou)

            #print(outputs.size())
            if args.visualize and args.steps_plot > 0 and step % args.steps_plot == 0:
                start_time_plot = time.time()
                image = inputs[0].cpu().data
                #image[0] = image[0] * .229 + .485
                #image[1] = image[1] * .224 + .456
                #image[2] = image[2] * .225 + .406
                #print("output", np.unique(outputs[0].cpu().max(0)[1].data.numpy()))
                board.image(image, f'input (epoch: {epoch}, step: {step})')
                if isinstance(outputs, list):  #merge gpu tensors
                    board.image(
                        color_transform(
                            outputs[0][0].cpu().max(0)[1].data.unsqueeze(0)),
                        f'output (epoch: {epoch}, step: {step})')
                else:
                    board.image(
                        color_transform(
                            outputs[0].cpu().max(0)[1].data.unsqueeze(0)),
                        f'output (epoch: {epoch}, step: {step})')
                board.image(color_transform(targets[0].cpu().data),
                            f'target (epoch: {epoch}, step: {step})')
                print("Time to paint images: ", time.time() - start_time_plot)
            if args.steps_loss > 0 and step % args.steps_loss == 0:
                average = sum(epoch_loss) / len(epoch_loss)
                print(
                    f'loss: {average:0.4} (epoch: {epoch}, step: {step})',
                    "// Avg time/img: %.4f s" %
                    (sum(time_train) / len(time_train) / args.batch_size))

        average_epoch_loss_train = sum(epoch_loss) / len(epoch_loss)

        iouTrain = 0
        if (doIouTrain):
            iouTrain, iou_classes = iouEvalTrain.getIoU()
            iouStr = getColorEntry(iouTrain) + '{:0.2f}'.format(
                iouTrain * 100) + '\033[0m'
            print("EPOCH IoU on TRAIN set: ", iouStr, "%")

        #Validate on 500 val images after each epoch of training
        print("----- VALIDATING - EPOCH", epoch, "-----")
        model.eval()
        epoch_loss_val = []
        time_val = []

        if (doIouVal):
            iouEvalVal = iouEval(NUM_CLASSES)

        for step, (images, labels) in enumerate(loader_val):
            start_time = time.time()
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()

            # NOTE(review): volatile=True is a no-op on modern PyTorch;
            # torch.no_grad() is the replacement, kept as-is to avoid
            # restructuring legacy code.
            inputs = Variable(
                images, volatile=True
            )  #volatile flag makes it free backward or outputs for eval
            targets = Variable(labels, volatile=True)
            outputs = model(inputs, only_encode=enc)

            loss = criterion(outputs, targets[:, 0])
            epoch_loss_val.append(loss.item())
            time_val.append(time.time() - start_time)

            #Add batch to calculate TP, FP and FN for iou estimation
            if (doIouVal):
                #start_time_iou = time.time()
                iouEvalVal.addBatch(
                    outputs.max(1)[1].unsqueeze(1).data, targets.data)
                #print ("Time to add confusion matrix: ", time.time() - start_time_iou)

            if args.visualize and args.steps_plot > 0 and step % args.steps_plot == 0:
                start_time_plot = time.time()
                image = inputs[0].cpu().data
                board.image(image, f'VAL input (epoch: {epoch}, step: {step})')
                if isinstance(outputs, list):  #merge gpu tensors
                    board.image(
                        color_transform(
                            outputs[0][0].cpu().max(0)[1].data.unsqueeze(0)),
                        f'VAL output (epoch: {epoch}, step: {step})')
                else:
                    board.image(
                        color_transform(
                            outputs[0].cpu().max(0)[1].data.unsqueeze(0)),
                        f'VAL output (epoch: {epoch}, step: {step})')
                board.image(color_transform(targets[0].cpu().data),
                            f'VAL target (epoch: {epoch}, step: {step})')
                print("Time to paint images: ", time.time() - start_time_plot)
            if args.steps_loss > 0 and step % args.steps_loss == 0:
                average = sum(epoch_loss_val) / len(epoch_loss_val)
                print(
                    f'VAL loss: {average:0.4} (epoch: {epoch}, step: {step})',
                    "// Avg time/img: %.4f s" %
                    (sum(time_val) / len(time_val) / args.batch_size))

        average_epoch_loss_val = sum(epoch_loss_val) / len(epoch_loss_val)
        #scheduler.step(average_epoch_loss_val, epoch)  ## scheduler 1   # update lr if needed

        iouVal = 0
        if (doIouVal):
            iouVal, iou_classes, accVal, acc_classes = iouEvalVal.getIoU()

            #IoU of 19 classes
            #print ("pole    : %.6f" % (iou_classes[0]*100.0), "%\t")
            print("road           : %.6f" % (iou_classes[0] * 100.0), "%\t")
            print("sidewalk       : %.6f" % (iou_classes[1] * 100.0), "%\t")
            print("building       : %.6f" % (iou_classes[2] * 100.0), "%\t")
            print("wall           : %.6f" % (iou_classes[3] * 100.0), "%\t")
            print("fence          : %.6f" % (iou_classes[4] * 100.0), "%\t")
            print("pole           : %.6f" % (iou_classes[5] * 100.0), "%\t")
            print("traffic light  : %.6f" % (iou_classes[6] * 100.0), "%\t")
            print("traffic sign   : %.6f" % (iou_classes[7] * 100.0), "%\t")
            print("vegetation     : %.6f" % (iou_classes[8] * 100.0), "%\t")
            print("terrain        : %.6f" % (iou_classes[9] * 100.0), "%\t")
            print("sky            : %.6f" % (iou_classes[10] * 100.0), "%\t")
            print("person         : %.6f" % (iou_classes[11] * 100.0), "%\t")
            print("rider          : %.6f" % (iou_classes[12] * 100.0), "%\t")
            print("car            : %.6f" % (iou_classes[13] * 100.0), "%\t")
            print("truck          : %.6f" % (iou_classes[14] * 100.0), "%\t")
            print("bus            : %.6f" % (iou_classes[15] * 100.0), "%\t")
            print("train          : %.6f" % (iou_classes[16] * 100.0), "%\t")
            print("motorcycle     : %.6f" % (iou_classes[17] * 100.0), "%\t")
            print("bicycle        : %.6f" % (iou_classes[18] * 100.0), "%\t")

            iouStr = getColorEntry(iouVal) + '{:0.2f}'.format(
                iouVal * 100) + '\033[0m'
            print("EPOCH IoU on VAL set: ", iouStr, "%")

            print("road           : %.6f" % (acc_classes[0] * 100.0), "%\t")
            print("sidewalk       : %.6f" % (acc_classes[1] * 100.0), "%\t")
            print("building       : %.6f" % (acc_classes[2] * 100.0), "%\t")
            print("wall           : %.6f" % (acc_classes[3] * 100.0), "%\t")
            print("fence          : %.6f" % (acc_classes[4] * 100.0), "%\t")
            print("pole           : %.6f" % (acc_classes[5] * 100.0), "%\t")
            print("traffic light  : %.6f" % (acc_classes[6] * 100.0), "%\t")
            print("traffic sign   : %.6f" % (acc_classes[7] * 100.0), "%\t")
            print("vegetation     : %.6f" % (acc_classes[8] * 100.0), "%\t")
            print("terrain        : %.6f" % (acc_classes[9] * 100.0), "%\t")
            print("sky            : %.6f" % (acc_classes[10] * 100.0), "%\t")
            print("person         : %.6f" % (acc_classes[11] * 100.0), "%\t")
            print("rider          : %.6f" % (acc_classes[12] * 100.0), "%\t")
            print("car            : %.6f" % (acc_classes[13] * 100.0), "%\t")
            print("truck          : %.6f" % (acc_classes[14] * 100.0), "%\t")
            print("bus            : %.6f" % (acc_classes[15] * 100.0), "%\t")
            print("train          : %.6f" % (acc_classes[16] * 100.0), "%\t")
            print("motorcycle     : %.6f" % (acc_classes[17] * 100.0), "%\t")
            print("bicycle        : %.6f" % (acc_classes[18] * 100.0), "%\t")

            accStr = getColorEntry(accVal) + '{:0.2f}'.format(
                accVal * 100) + '\033[0m'
            print("EPOCH ACC on VAL set: ", accStr, "%")

        # remember best valIoU and save checkpoint
        # When IoU is not computed, fall back to negative val loss so that
        # "higher is better" still holds for the best-model comparison.
        if iouVal == 0:
            current_acc = -average_epoch_loss_val
        else:
            current_acc = iouVal
        is_best = current_acc > best_acc
        best_acc = max(current_acc, best_acc)
        if enc:
            filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
            filenameBest = savedir + '/model_best_enc.pth.tar'
        else:
            filenameCheckpoint = savedir + '/checkpoint.pth.tar'
            filenameBest = savedir + '/model_best.pth.tar'
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': str(model),
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best, filenameCheckpoint, filenameBest)

        #SAVE MODEL AFTER EPOCH
        if (enc):
            filename = f'{savedir}/model_encoder-{epoch:03}.pth'
            filenamebest = f'{savedir}/model_encoder_best.pth'
        else:
            filename = f'{savedir}/model-{epoch:03}.pth'
            filenamebest = f'{savedir}/model_best.pth'
        # FIX: periodic save keyed on epoch (previously keyed on `step`,
        # the leftover index of the last validation batch).
        if args.epochs_save > 0 and epoch > 0 and epoch % args.epochs_save == 0:
            torch.save(model.state_dict(), filename)
            print(f'save: {filename} (epoch: {epoch})')
        if (is_best):
            torch.save(model.state_dict(), filenamebest)
            print(f'save: {filenamebest} (epoch: {epoch})')
            if (not enc):
                with open(savedir + "/best.txt", "w") as myfile:
                    myfile.write("Best epoch is %d, with Val-IoU= %.4f" %
                                 (epoch, iouVal))
            else:
                with open(savedir + "/best_encoder.txt", "w") as myfile:
                    myfile.write("Best epoch is %d, with Val-IoU= %.4f" %
                                 (epoch, iouVal))

        #SAVE TO FILE A ROW WITH THE EPOCH RESULT (train loss, val loss, train IoU, val IoU)
        #Epoch		Train-loss		Test-loss	Train-IoU	Test-IoU		learningRate
        with open(automated_log_path, "a") as myfile:
            myfile.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.8f" %
                         (epoch, average_epoch_loss_train,
                          average_epoch_loss_val, iouTrain, iouVal, usedLr))

    return (model)  #return model (convenience for encoder-decoder training)
Exemple #30
0
def main_single(rank, FLAGS):
    """Per-process entry point: build data/model, optionally resume, then
    train or evaluate the energy model.

    Args:
        rank: local rank within this node (0..FLAGS.gpus-1).
        FLAGS: experiment flags (nodes, gpus, node_rank, cuda, logdir, exp,
            data_workers, batch_size, lr, resume_iter, train).
    """
    # Global rank across all nodes; world size spans nodes * gpus.
    rank_idx = FLAGS.node_rank * FLAGS.gpus + rank
    world_size = FLAGS.nodes * FLAGS.gpus

    if world_size > 1:
        dist.init_process_group(backend='nccl',
                                init_method='tcp://localhost:1492',
                                world_size=world_size,
                                rank=rank_idx)

    if FLAGS.cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if not osp.exists(logdir):
        os.makedirs(logdir)
    # logger = TensorBoardOutputFormat(logdir)

    ## dataset
    dataset_train = RobotDatasetGrasp('train')
    dataset_test = RobotDatasetGrasp('test')

    train_dataloader = DataLoader(dataset_train,
                                  num_workers=FLAGS.data_workers,
                                  batch_size=FLAGS.batch_size,
                                  shuffle=True,
                                  pin_memory=False,
                                  drop_last=False)
    test_dataloader = DataLoader(dataset_test,
                                 num_workers=FLAGS.data_workers,
                                 batch_size=FLAGS.batch_size,
                                 shuffle=False,
                                 pin_memory=False,
                                 drop_last=False)

    ## model
    model = EnergyModel().to(device)
    optimizer = Adam(model.parameters(), lr=FLAGS.lr, betas=(0.9, 0.999))

    if FLAGS.resume_iter != 0:
        model_path = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
        checkpoint = torch.load(model_path)
        FLAGS_OLD = checkpoint['FLAGS']  # kept for parity with saved runs

        try:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            model.load_state_dict(checkpoint['model_state_dict'])
        except Exception:
            # Checkpoint was saved from a DataParallel/DDP wrapper: strip the
            # "module." prefix and retry. FIX: the stripped dict was
            # previously built but never loaded, so resuming from a wrapped
            # checkpoint silently left the model at random init.
            model_state_dict = {
                k.replace("module.", ""): v
                for k, v in checkpoint['model_state_dict'].items()
            }
            model.load_state_dict(model_state_dict)

    if FLAGS.gpus > 1:
        sync_model(model)

    if FLAGS.train:
        model = model.train()
        train(train_dataloader, test_dataloader, model, optimizer, FLAGS,
              logdir, rank_idx)
    else:
        model = model.eval()
        test(test_dataloader, model, FLAGS)
Exemple #31
0
class Trainer():
    def __init__(self, train_dataloader, test_dataloader, lr, betas, weight_decay, log_freq, with_cuda, model=None):
        """Set up the audio-CNN trainer: model, optimizer, scheduler, loss.

        :param train_dataloader: DataLoader yielding training batches
        :param test_dataloader: DataLoader yielding evaluation batches
        :param lr: learning rate for Adam
        :param betas: Adam (beta1, beta2) tuple
        :param weight_decay: Adam weight decay
        :param log_freq: logging frequency (stored, used by callers/iteration)
        :param with_cuda: request CUDA if it is actually available
        :param model: optional path to a checkpoint file to resume from
        """
        # Fall back to CPU when CUDA was requested but is unavailable.
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda" if cuda_condition else "cpu")
        print("Use:", "cuda:0" if cuda_condition else "cpu")

        self.model = cnn_audio().to(self.device)
        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        self.scheduler = lr_scheduler.CosineAnnealingLR(self.optim, 5)
        self.criterion = nn.BCEWithLogitsLoss()

        if model is not None:
            # Resume from checkpoint. save() historically stored the criterion
            # under 'criterion' while this loader expected 'loss' -- accept
            # either key so both old and new checkpoints restore cleanly.
            checkpoint = torch.load(model)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optim.load_state_dict(checkpoint['optimizer_state_dict'])
            self.epoch = checkpoint['epoch']
            if 'loss' in checkpoint:
                self.criterion = checkpoint['loss']
            elif 'criterion' in checkpoint:
                self.criterion = checkpoint['criterion']

        # Wrap in DataParallel when several GPUs are visible.
        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
        print("Using %d GPUS for Converter" % torch.cuda.device_count())

        self.train_data = train_dataloader
        self.test_data = test_dataloader

        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

        # Per-epoch metric histories, appended by iteration().
        self.test_loss = []
        self.train_loss = []
        self.train_f1_score = []
        self.test_f1_score = []
    
    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """Run one full pass over *data_loader*, tracking loss and F1.

        :param epoch: current epoch number (used for logging only)
        :param data_loader: torch.utils.data.DataLoader; batches are indexed
            as data[0]=spectrogram, data[1]=one-hot label, data[2]=label
            (order inferred from usage below -- TODO confirm against dataset)
        :param train: True -> backprop + optimizer/scheduler steps;
            False -> evaluation only
        """
        str_code = "train" if train else "test"

        # Progress bar over batches; desc distinguishes train/test phase.
        data_iter = tqdm(enumerate(data_loader), desc="EP_%s:%d" % (str_code, epoch), total=len(data_loader), bar_format="{l_bar}{r_bar}")
        
        total_element = 0
        loss_store = 0.0
        f1_score_store = 0.0
        total_correct = 0

        for i, data in data_iter:
            specgram = data[0].to(self.device)
            label = data[2].to(self.device)
            one_hot_label = data[1].to(self.device)
            predict_label = self.model(specgram)

            # Micro-averaged F1 between true labels and the converted
            # (presumably thresholded -- verify convert_label) predictions.
            predict_f1_score = get_F1_score(
                label.cpu().detach().numpy(),
                convert_label(predict_label.cpu().detach().numpy()),
                average='micro'
            )
            
            loss = self.criterion(predict_label, one_hot_label)

            # Backprop only in training mode. NOTE(review): the cosine LR
            # scheduler is stepped per-batch here, while CosineAnnealingLR is
            # conventionally stepped per-epoch -- confirm this is intended.
            if train:
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                self.scheduler.step()

            # Running averages over the batches seen so far this pass.
            loss_store += loss.item()
            f1_score_store += predict_f1_score
            self.avg_loss = loss_store / (i + 1)
            self.avg_f1_score = f1_score_store / (i + 1)
        
            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": round(self.avg_loss, 5),
                "loss": round(loss.item(), 5),
                "avg_f1_score": round(self.avg_f1_score, 5)
            }

        # NOTE(review): post_fix / self.avg_loss only exist if the loader
        # yielded at least one batch; an empty loader raises NameError here.
        data_iter.write(str(post_fix))
        # Conditional-expression-as-statement: route the epoch averages into
        # the train or test history depending on the mode.
        self.train_loss.append(self.avg_loss) if train else self.test_loss.append(self.avg_loss)
        self.train_f1_score.append(self.avg_f1_score) if train else self.test_f1_score.append(self.avg_f1_score)
        
    
    def save(self, epoch, file_path="../models/2f/"):
        """
        """
        output_path = file_path + f"crnn_ep{epoch}.model"
        torch.save(
            {
            'epoch': epoch,
            'model_state_dict': self.model.cpu().state_dict(),
            'optimizer_state_dict': self.optim.state_dict(),
            'criterion': self.criterion
            },
            output_path)
        self.model.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
    
    def export_log(self, epoch, file_path="../../logs/2f/"):
        df = pd.DataFrame({
            "train_loss": self.train_loss, 
            "test_loss": self.test_loss, 
            "train_F1_score": self.train_f1_score,
            "test_F1_score": self.test_f1_score
        })
        output_path = file_path+f"loss_{epoch}.log"
        print("EP:%d logs Saved on:" % epoch, output_path)
        df.to_csv(output_path)
    def fit(self, model, feature_extraction, protocol, log_dir, subset='train',
            epochs=1000, restart=0, gpu=False):
        """Train *model* with a triplet loss plus a variant-dependent
        confidence loss, logging metrics/histograms to tensorboardX and
        checkpointing every epoch.

        :param model: embedding network to train
        :param feature_extraction: feature extractor handed to the batch generator
        :param protocol: dataset protocol consumed by SpeechSegmentGenerator
            (presumably a pyannote protocol -- TODO confirm)
        :param log_dir: directory for tensorboard logs and checkpoints
        :param subset: protocol subset to draw batches from
        :param epochs: last epoch index to run (inclusive)
        :param restart: epoch to resume from; 0 means train from scratch
        :param gpu: move model and auxiliary BN layers to CUDA when True
        """

        import tensorboardX
        writer = tensorboardX.SummaryWriter(log_dir=log_dir)

        checkpoint = Checkpoint(log_dir=log_dir,
                                      restart=restart > 0)

        # Infinite generator of speaker-balanced batches.
        batch_generator = SpeechSegmentGenerator(
            feature_extraction,
            per_label=self.per_label, per_fold=self.per_fold,
            duration=self.duration, parallel=self.parallel)
        batches = batch_generator(protocol, subset=subset)
        batch = next(batches)

        batches_per_epoch = batch_generator.batches_per_epoch

        # Resume model weights when restarting from a previous epoch.
        if restart > 0:
            weights_pt = checkpoint.WEIGHTS_PT.format(
                log_dir=log_dir, epoch=restart)
            model.load_state_dict(torch.load(weights_pt))

        if gpu:
            model = model.cuda()

        model.internal = False

        parameters = list(model.parameters())

        # Variants 2-8 learn an affine batch-norm over embedding norms;
        # its parameters are optimized jointly with the model.
        if self.variant in [2, 3, 4, 5, 6, 7, 8]:

            # norm batch-normalization
            self.norm_bn = nn.BatchNorm1d(
                1, eps=1e-5, momentum=0.1, affine=True)
            if gpu:
                self.norm_bn = self.norm_bn.cuda()
            parameters += list(self.norm_bn.parameters())

        if self.variant in [9]:
            # norm batch-normalization (non-affine for variant 9)
            self.norm_bn = nn.BatchNorm1d(
                1, eps=1e-5, momentum=0.1, affine=False)
            if gpu:
                self.norm_bn = self.norm_bn.cuda()
            parameters += list(self.norm_bn.parameters())

        # Variants 5-7 additionally normalize positive/negative distances.
        if self.variant in [5, 6, 7]:
            self.positive_bn = nn.BatchNorm1d(
                1, eps=1e-5, momentum=0.1, affine=False)
            self.negative_bn = nn.BatchNorm1d(
                1, eps=1e-5, momentum=0.1, affine=False)
            if gpu:
                self.positive_bn = self.positive_bn.cuda()
                self.negative_bn = self.negative_bn.cuda()
            parameters += list(self.positive_bn.parameters())
            parameters += list(self.negative_bn.parameters())

        # Variants 8-9 normalize the triplet deltas d(a,p) - d(a,n).
        if self.variant in [8, 9]:

            self.delta_bn = nn.BatchNorm1d(
                1, eps=1e-5, momentum=0.1, affine=False)
            if gpu:
                self.delta_bn = self.delta_bn.cuda()
            parameters += list(self.delta_bn.parameters())

        optimizer = Adam(parameters)
        # Resume optimizer state and move its tensors back to the GPU
        # (torch.load restores them on CPU).
        if restart > 0:
            optimizer_pt = checkpoint.OPTIMIZER_PT.format(
                log_dir=log_dir, epoch=restart)
            optimizer.load_state_dict(torch.load(optimizer_pt))
            if gpu:
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.cuda()

        # Epoch counter: resume at `restart`, or start at 0 (first increment
        # takes -1 -> 0); stop after `epochs`.
        epoch = restart if restart > 0 else -1
        while True:
            epoch += 1
            if epoch > epochs:
                break

            loss_avg, tloss_avg, closs_avg = 0., 0., 0.

            # Every 5th epoch, collect raw values for histogram logging.
            if epoch % 5 == 0:
                log_positive = []
                log_negative = []
                log_delta = []
                log_norm = []

            desc = 'Epoch #{0}'.format(epoch)
            for i in tqdm(range(batches_per_epoch), desc=desc):

                model.zero_grad()

                batch = next(batches)

                # Convert the batch to a float32 Variable, transposing to
                # sequence-first when the model is not batch-first.
                X = batch['X']
                if not getattr(model, 'batch_first', True):
                    X = np.rollaxis(X, 0, 2)
                X = np.array(X, dtype=np.float32)
                X = Variable(torch.from_numpy(X))

                if gpu:
                    X = X.cuda()

                fX = model(X)

                # pre-compute pairwise distances
                distances = self.pdist(fX)

                # sample triplets via the configured strategy
                # (dispatches to self.batch_<sampling>, e.g. batch_hard)
                triplets = getattr(self, 'batch_{0}'.format(self.sampling))
                anchors, positives, negatives = triplets(batch['y'], distances)

                # compute triplet loss (deltas = d(a,p) - d(a,n))
                tlosses, deltas, pos_index, neg_index  = self.triplet_loss(
                    distances, anchors, positives, negatives,
                    return_delta=True)

                tloss = torch.mean(tlosses)

                # Confidence loss, one formulation per variant.
                if self.variant == 1:

                    closses = F.sigmoid(
                        F.softsign(deltas) * torch.norm(fX[anchors], 2, 1, keepdim=True))

                    # if d(a, p) < d(a, n) (i.e. good case)
                    #   --> sign(delta) < 0
                    #   --> loss decreases when norm increases.
                    #       i.e. encourages longer anchor

                    # if d(a, p) > d(a, n) (i.e. bad case)
                    #   --> sign(delta) > 0
                    #   --> loss increases when norm increases
                    #       i.e. encourages shorter anchor

                elif self.variant == 2:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    norms_ = F.sigmoid(self.norm_bn(norms_))

                    confidence = (norms_[anchors] + norms_[positives] + norms_[negatives]) / 3
                    # if |x| is average
                    #    --> normalized |x| = 0
                    #    --> confidence = 0.5

                    # if |x| is bigger than average
                    #    --> normalized |x| >> 0
                    #    --> confidence = 1

                    # if |x| is smaller than average
                    #    --> normalized |x| << 0
                    #    --> confidence = 0

                    correctness = F.sigmoid(-deltas / np.pi * 6)
                    # if d(a, p) = d(a, n) (i.e. uncertain case)
                    #    --> correctness = 0.5

                    # if d(a, p) - d(a, n) = -𝛑 (i.e. best possible case)
                    #    --> correctness = 1

                    # if d(a, p) - d(a, n) = +𝛑 (i.e. worst possible case)
                    #    --> correctness = 0

                    closses = torch.abs(confidence - correctness)
                    # small if (and only if) confidence & correctness agree

                elif self.variant == 3:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    norms_ = F.sigmoid(self.norm_bn(norms_))
                    confidence = (norms_[anchors] * norms_[positives] * norms_[negatives]) / 3

                    correctness = F.sigmoid(-(deltas + np.pi / 4) / np.pi * 6)
                    # correctness = 0.5 at delta == -pi/4
                    # correctness = 1 for delta == -pi
                    # correctness = 0 for delta < 0

                    closses = torch.abs(confidence - correctness)

                elif self.variant == 4:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    norms_ = F.sigmoid(self.norm_bn(norms_))
                    # NOTE(review): "** 1/3" parses as (x ** 1) / 3, not a
                    # cube root -- this matches variant 3's "/ 3"; if a
                    # geometric mean was intended it should be ** (1/3).
                    confidence = (norms_[anchors] * norms_[positives] * norms_[negatives]) ** 1/3

                    correctness = F.sigmoid(-(deltas + np.pi / 4) / np.pi * 6)
                    # correctness = 0.5 at delta == -pi/4
                    # correctness = 1 for delta == -pi
                    # correctness = 0 for delta < 0

                    # delta = pos - neg ... should be < 0

                    closses = torch.abs(confidence - correctness)

                elif self.variant == 5:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    confidence = F.sigmoid(self.norm_bn(norms_))

                    confidence_pos = .5 * (confidence[anchors] + confidence[positives])
                    # low positive distance == high correctness
                    correctness_pos = F.sigmoid(
                        -self.positive_bn(distances[pos_index].view(-1, 1)))

                    confidence_neg = .5 * (confidence[anchors] + confidence[negatives])
                    # high negative distance == high correctness
                    correctness_neg = F.sigmoid(
                        self.negative_bn(distances[neg_index].view(-1, 1)))

                    closses = .5 * (torch.abs(confidence_pos - correctness_pos) \
                                  + torch.abs(confidence_neg - correctness_neg))

                elif self.variant == 6:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    confidence = F.sigmoid(self.norm_bn(norms_))

                    confidence_pos = .5 * (confidence[anchors] + confidence[positives])
                    # low positive distance == high correctness
                    correctness_pos = F.sigmoid(
                        -self.positive_bn(distances[pos_index].view(-1, 1)))

                    closses = torch.abs(confidence_pos - correctness_pos)

                elif self.variant == 7:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    confidence = F.sigmoid(self.norm_bn(norms_))

                    confidence_neg = .5 * (confidence[anchors] + confidence[negatives])
                    # high negative distance == high correctness
                    correctness_neg = F.sigmoid(
                        self.negative_bn(distances[neg_index].view(-1, 1)))

                    closses = torch.abs(confidence_neg - correctness_neg)

                elif self.variant in [8, 9]:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    norms_ = F.sigmoid(self.norm_bn(norms_))
                    confidence = (norms_[anchors] * norms_[positives] * norms_[negatives]) / 3

                    correctness = F.sigmoid(-self.delta_bn(deltas))
                    closses = torch.abs(confidence - correctness)

                closs = torch.mean(closses)

                # Every 5th epoch, stash raw numpy values for histograms.
                if epoch % 5 == 0:

                    if gpu:
                        fX_npy = fX.data.cpu().numpy()
                        pdist_npy = distances.data.cpu().numpy()
                        delta_npy = deltas.data.cpu().numpy()
                    else:
                        fX_npy = fX.data.numpy()
                        pdist_npy = distances.data.numpy()
                        delta_npy = deltas.data.numpy()

                    log_norm.append(np.linalg.norm(fX_npy, axis=1))

                    # chebyshev distance < 1 between integer labels means
                    # "same speaker" (scipy.spatial.distance.pdist assumed).
                    same_speaker = pdist(batch['y'].reshape((-1, 1)), metric='chebyshev') < 1
                    log_positive.append(pdist_npy[np.where(same_speaker)])
                    log_negative.append(pdist_npy[np.where(~same_speaker)])

                    log_delta.append(delta_npy)

                # log loss
                if gpu:
                    tloss_ = float(tloss.data.cpu().numpy())
                    closs_ = float(closs.data.cpu().numpy())
                else:
                    tloss_ = float(tloss.data.numpy())
                    closs_ = float(closs.data.numpy())
                tloss_avg += tloss_
                closs_avg += closs_
                loss_avg += tloss_ + closs_

                # Joint objective: triplet loss + confidence loss.
                loss = tloss + closs
                loss.backward()
                optimizer.step()

            # Epoch-level scalar logging (averaged over batches).
            tloss_avg /= batches_per_epoch
            writer.add_scalar('tloss', tloss_avg, global_step=epoch)

            closs_avg /= batches_per_epoch
            writer.add_scalar('closs', closs_avg, global_step=epoch)

            loss_avg /= batches_per_epoch
            writer.add_scalar('loss', loss_avg, global_step=epoch)

            # Every 5th epoch: distance/norm/delta histograms and EER.
            if epoch % 5 == 0:

                log_positive = np.hstack(log_positive)
                writer.add_histogram(
                    'embedding/pairwise_distance/positive', log_positive,
                    global_step=epoch, bins=np.linspace(0, np.pi, 50))
                log_negative = np.hstack(log_negative)

                writer.add_histogram(
                    'embedding/pairwise_distance/negative', log_negative,
                    global_step=epoch, bins=np.linspace(0, np.pi, 50))

                # Equal error rate from positive/negative distance scores.
                _, _, _, eer = det_curve(
                    np.hstack([np.ones(len(log_positive)), np.zeros(len(log_negative))]),
                    np.hstack([log_positive, log_negative]), distances=True)
                writer.add_scalar('eer', eer, global_step=epoch)

                log_norm = np.hstack(log_norm)
                writer.add_histogram(
                    'norm', log_norm,
                    global_step=epoch, bins='doane')

                log_delta = np.vstack(log_delta)
                writer.add_histogram(
                    'delta', log_delta,
                    global_step=epoch, bins='doane')

            checkpoint.on_epoch_end(epoch, model, optimizer)

            # Persist the confidence batch-norm alongside the checkpoint.
            if hasattr(self, 'norm_bn'):
                confidence_pt = self.CONFIDENCE_PT.format(
                    log_dir=log_dir, epoch=epoch)
                torch.save(self.norm_bn.state_dict(), confidence_pt)
    def fit(self, model, feature_extraction, protocol, log_dir, subset='train',
            epochs=1000, restart=None, gpu=False):

        import tensorboardX
        writer = tensorboardX.SummaryWriter(log_dir=log_dir)

        checkpoint = Checkpoint(
            log_dir=log_dir, restart=(False if restart is None else True))
        try:
            batch_generator = SpeechSegmentGenerator(
                feature_extraction,
                per_label=self.per_label, per_fold=self.per_fold,
                duration=self.duration)
            batches = batch_generator(protocol, subset=subset)
            batch = next(batches)
        except OSError as e:
            del batch_generator.data_
            batch_generator = SpeechSegmentGenerator(
                feature_extraction,
                per_label=self.per_label, per_fold=self.per_fold,
                duration=self.duration, fast=False)
            batches = batch_generator(protocol, subset=subset)
            batch = next(batches)

        # one minute per speaker
        duration_per_epoch = 60. * batch_generator.n_labels
        duration_per_batch = self.duration * batch_generator.n_sequences_per_batch
        batches_per_epoch = int(np.ceil(duration_per_epoch / duration_per_batch))

        if restart is not None:
            weights_pt = checkpoint.WEIGHTS_PT.format(
                log_dir=log_dir, epoch=restart)
            model.load_state_dict(torch.load(weights_pt))

        if gpu:
            model = model.cuda()

        model.internal = False

        n_domains = len(batch_generator.domains_[self.domain])
        if n_domains < 2:
            raise ValueError('There must be more than one domain.')

        domain_clf = DomainClassifier(model.output_dim, n_domains, alpha=1.)
        if gpu:
            domain_clf = domain_clf.cuda()

        domain_loss = nn.CrossEntropyLoss()

        optimizer = Adam(list(model.parameters()) + list(domain_clf.parameters()))
        if restart is not None:
            optimizer_pt = checkpoint.OPTIMIZER_PT.format(
                log_dir=log_dir, epoch=restart)
            optimizer.load_state_dict(torch.load(optimizer_pt))
            if gpu:
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.cuda()

        restart = 0 if restart is None else restart + 1
        for epoch in range(restart, restart + epochs):

            tloss_avg = 0.
            dloss_avg = 0.
            loss_avg = 0.
            dacc_avg = 0.
            positive, negative = [], []
            if not model.normalize:
                norms = []

            desc = 'Epoch #{0}'.format(epoch)
            for i in tqdm(range(batches_per_epoch), desc=desc):

                model.zero_grad()

                batch = next(batches)

                X = batch['X']
                if not getattr(model, 'batch_first', True):
                    X = np.rollaxis(X, 0, 2)
                X = np.array(X, dtype=np.float32)
                X = Variable(torch.from_numpy(X))

                y = batch['y']
                y_domain = batch['y_{domain}'.format(domain=self.domain)]

                if gpu:
                    X = X.cuda()
                fX = model(X)

                if not model.normalize:
                    if gpu:
                        fX_ = fX.data.cpu().numpy()
                    else:
                        fX_ = fX.data.numpy()
                    norms.append(np.linalg.norm(fX_, axis=0))

                triplet_losses = []
                for d, domain in enumerate(np.unique(y_domain)):

                    this_domain = np.where(y_domain == domain)[0]

                    domain_y = y[this_domain]

                    # if there is less than 2 speakers in this domain, skip it
                    if len(np.unique(domain_y)) < 2:
                        continue

                    # pre-compute within-domain pairwise distances
                    domain_fX = fX[this_domain, :]
                    domain_pdist = self.pdist(domain_fX)

                    if gpu:
                        domain_pdist_ = domain_pdist.data.cpu().numpy()
                    else:
                        domain_pdist_ = domain_pdist.data.numpy()
                    is_positive = pdist(domain_y.reshape((-1, 1)), metric='chebyshev') < 1
                    positive.append(domain_pdist_[np.where(is_positive)])
                    negative.append(domain_pdist_[np.where(~is_positive)])

                    # sample triplets
                    if self.sampling == 'all':
                        anchors, positives, negatives = self.batch_all(
                            domain_y, domain_pdist)

                    elif self.sampling == 'hard':
                        anchors, positives, negatives = self.batch_hard(
                            domain_y, domain_pdist)

                    # compute triplet loss
                    triplet_losses.append(
                        self.triplet_loss(domain_pdist, anchors,
                                          positives, negatives))

                tloss = 0.
                for tl in triplet_losses:
                    tloss += torch.mean(tl)
                tloss /= len(triplet_losses)

                if gpu:
                    tloss_ = float(tloss.data.cpu().numpy())
                else:
                    tloss_ = float(tloss.data.numpy())
                tloss_avg += tloss_

                # domain-adversarial
                y_domain = Variable(torch.from_numpy(np.array(y_domain)))
                if gpu:
                    y_domain = y_domain.cuda()

                domain_score = domain_clf(fX)
                dloss = domain_loss(domain_score, y_domain)

                if gpu:
                    dloss_ = float(dloss.data.cpu().numpy())
                else:
                    dloss_ = float(dloss.data.numpy())
                dloss_avg += dloss_

                # log domain classification accuracy
                if gpu:
                    domain_score_ = domain_score.data.cpu().numpy()
                    y_domain_ = y_domain.data.cpu().numpy()
                else:
                    domain_score_ = domain_score.data.numpy()
                    y_domain_ = y_domain.data.numpy()
                dacc_ = np.mean(np.argmax(domain_score_, axis=1) == y_domain_)
                dacc_avg += dacc_

                loss = tloss + dloss

                if gpu:
                    loss_ = float(loss.data.cpu().numpy())
                else:
                    loss_ = float(loss.data.numpy())
                loss_avg += loss_
                loss_avg += loss_

                loss.backward()
                optimizer.step()

            # if gpu:
            #     embeddings = fX.data.cpu()
            # else:
            #     embeddings = fX.data
            # metadata = list(batch['extra'][self.domain])
            #
            # writer.add_embedding(embeddings, metadata=metadata,
            #                      global_step=epoch)

            tloss_avg /= batches_per_epoch
            writer.add_scalar('tloss', tloss_avg, global_step=epoch)
            dloss_avg /= batches_per_epoch
            writer.add_scalar('dloss', dloss_avg, global_step=epoch)
            loss_avg /= batches_per_epoch
            writer.add_scalar('loss', loss_avg, global_step=epoch)
            dacc_avg /= batches_per_epoch
            writer.add_scalar('dacc', dacc_avg, global_step=epoch)

            positive = np.hstack(positive)
            negative = np.hstack(negative)
            writer.add_histogram(
                'embedding/pairwise_distance/positive', positive,
                global_step=epoch, bins='auto')
            writer.add_histogram(
                'embedding/pairwise_distance/negative', negative,
                global_step=epoch, bins='auto')

            if not model.normalize:
                norms = np.hstack(norms)
                writer.add_histogram(
                    'embedding/norm', norms,
                    global_step=epoch, bins='auto')

            checkpoint.on_epoch_end(epoch, model, optimizer)