Example #1
class IQN(nn.Module):
    def __init__(self, state_space, act_n, quantile_dim, num_quantiles,
                 hidden_dim, num_hidden, optim_params):
        """
        Rainbow Recurrent IQN

        IQN: https://arxiv.org/pdf/1806.06923.pdf
        R2D2: https://openreview.net/pdf?id=r1lyTjAqYX
        R2D3: https://arxiv.org/abs/1909.01387
        """
        nn.Module.__init__(self)

        self.online = Model(state_space, act_n, quantile_dim, num_quantiles,
                            hidden_dim, num_hidden)
        self.target = deepcopy(self.online)

        self.loss_func = nn.SmoothL1Loss(reduction="mean")
        self.optim = RAdam(self.online.parameters(), **optim_params)
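
The constructor above forwards a keyword dictionary straight to RAdam. As a minimal sketch of what optim_params might contain (values are illustrative and assume the reference RAdam implementation, which mirrors Adam's signature; the import path depends on which RAdam distribution is installed):

import torch.nn as nn
from radam import RAdam  # import path is an assumption

net = nn.Linear(8, 4)  # stand-in for the Model above
optim_params = {
    'lr': 1e-4,
    'betas': (0.9, 0.999),
    'eps': 1e-8,
    'weight_decay': 0.0,
}
optim = RAdam(net.parameters(), **optim_params)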
Example #2
elif args.arch == 'stacked':
    model = stacked_transformer_model
else:
    raise ValueError(f'unsupported arch: {args.arch}')

if args.optimizer.lower() == 'adam':
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(args.beta1, args.beta2))
elif args.optimizer.lower() == 'sgd':
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum)
elif args.optimizer.lower() == 'radam':
    optimizer = RAdam(model.parameters(),
                      lr=args.lr,
                      betas=(args.beta1, args.beta2))
else:
    raise ValueError(f'unsupported optimizer: {args.optimizer}')
iterator = BucketIterator(batch_size=args.batch,
                          sorting_keys=[("source", "num_tokens")])
iterator.index_with(vocab)

#scheduler = _PyTorchLearningRateSchedulerWrapper(ReduceLROnPlateau(optimizer, patience=4))

if torch.cuda.is_available():
    cuda_device = 0
    model = model.cuda(cuda_device)
    print('using gpu')
else:
    cuda_device = -1
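
The snippet above reads several command-line flags; a hedged sketch of the argparse setup it appears to assume (flag names taken from the code, defaults purely illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--arch', default='stacked')          # model architecture switch
parser.add_argument('--optimizer', default='radam')       # 'adam', 'sgd' or 'radam'
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--beta1', type=float, default=0.9)
parser.add_argument('--beta2', type=float, default=0.999)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--batch', type=int, default=32)      # BucketIterator batch size
args = parser.parse_args()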
Example #3
class Trainer:
    def __init__(self,
                 encoder,
                 decoder,
                 optimizer_params={},
                 amp_params={},
                 n_jobs=0,
                 rank=0):

        lr = optimizer_params.get('lr', 1e-3)
        weight_decay = optimizer_params.get('weight_decay', 0)
        warmap = optimizer_params.get('warmap', 100)
        amsgrad = optimizer_params.get('amsgrad', False)
        opt_level = amp_params.get('opt_level', 'O0')
        loss_scale = amp_params.get('loss_scale', None)

        self.device = torch.device('cuda:' + str(rank))
        self.encoder = encoder.to(self.device)
        #self.decoder = decoder.to(self.device)
        self.num_classes = decoder.num_classes
        self.mse_critetion = nn.L1Loss()
        self.ce_criterion = LabelSmoothingLoss(self.num_classes,
                                               smoothing=0.1,
                                               reduction='none').to(
                                                   self.device)
        self.vat_criterion = VATLoss()
        self.cutmix = CutMix(self.num_classes)

        param_optimizer = list(self.encoder.named_parameters())  # + list(self.decoder.named_parameters())
        no_decay = ['bn', 'bias']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': weight_decay
            },
            {
                'params': [p for n, p in param_optimizer
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            },
        ]

        self.optimizer = RAdam(optimizer_grouped_parameters,
                               lr=lr,
                               weight_decay=weight_decay)

        self.is_master = torch.distributed.get_rank() == 0
        torch.cuda.set_device(rank)
        [self.encoder], self.optimizer = apex.amp.initialize(
            [self.encoder],
            self.optimizer,
            opt_level=opt_level,
            loss_scale=loss_scale,
            verbosity=1)

        self.scheduler = StepLR(self.optimizer, step_size=20, gamma=0.5)

        self.encoder = apex.parallel.DistributedDataParallel(
            self.encoder, delay_allreduce=True)
        #self.decoder = apex.parallel.DistributedDataParallel(self.decoder, delay_allreduce=True)

        self.last_epoch = 0
        self.n_jobs = n_jobs
Example #4
class Trainer:
    def __init__(self,
                 encoder,
                 decoder,
                 optimizer_params={},
                 amp_params={},
                 n_jobs=0,
                 rank=0):

        lr = optimizer_params.get('lr', 1e-3)
        weight_decay = optimizer_params.get('weight_decay', 0)
        warmap = optimizer_params.get('warmap', 100)
        amsgrad = optimizer_params.get('amsgrad', False)
        opt_level = amp_params.get('opt_level', 'O0')
        loss_scale = amp_params.get('loss_scale', None)

        self.device = torch.device('cuda:' + str(rank))
        self.encoder = encoder.to(self.device)
        #self.decoder = decoder.to(self.device)
        self.num_classes = decoder.num_classes
        self.mse_critetion = nn.L1Loss()
        self.ce_criterion = LabelSmoothingLoss(self.num_classes,
                                               smoothing=0.1,
                                               reduction='none').to(
                                                   self.device)
        self.vat_criterion = VATLoss()
        self.cutmix = CutMix(self.num_classes)

        param_optimizer = list(self.encoder.named_parameters())  # + list(self.decoder.named_parameters())
        no_decay = ['bn', 'bias']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': weight_decay
            },
            {
                'params': [p for n, p in param_optimizer
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            },
        ]

        self.optimizer = RAdam(optimizer_grouped_parameters,
                               lr=lr,
                               weight_decay=weight_decay)

        self.is_master = torch.distributed.get_rank() == 0
        torch.cuda.set_device(rank)
        [self.encoder], self.optimizer = apex.amp.initialize(
            [self.encoder],
            self.optimizer,
            opt_level=opt_level,
            loss_scale=loss_scale,
            verbosity=1)

        self.scheduler = StepLR(self.optimizer, step_size=20, gamma=0.5)

        self.encoder = apex.parallel.DistributedDataParallel(
            self.encoder, delay_allreduce=True)
        #self.decoder = apex.parallel.DistributedDataParallel(self.decoder, delay_allreduce=True)

        self.last_epoch = 0
        self.n_jobs = n_jobs

    def _train_epoch(self, train_dataloader):
        if self.is_master:
            pbar = tqdm(desc=f'Train, epoch #{self.last_epoch}',
                        total=len(train_dataloader))

        self.encoder.train()
        #self.decoder.train()

        sum_loss, cls_loss = AvgMeter(), AvgMeter()
        for images, labels in train_dataloader:
            images, labels, shuffled_labels, l = self.cutmix(images, labels)
            images = images.to(self.device)
            labels = labels.to(self.device)
            shuffled_labels = shuffled_labels.to(self.device)
            l = l.to(self.device)

            self.optimizer.zero_grad()

            #loss_vat = self.vat_criterion(self.encoder, images)

            label_preds = self.encoder(images)
            #reconsts_l = self.decoder(latents, labels)
            #with disable_tracking_bn_stats(self.encoder):
            #    latents_l, label_preds_l = self.encoder(reconsts_l)
            #labels_r = torch.randint_like(labels, low=0, high=self.num_classes)
            #reconsts_r = self.decoder(latents, labels_r)
            #with disable_tracking_bn_stats(self.encoder):
            #    latents_r, label_preds_r = self.encoder(reconsts_r)

            loss_c = (l * self.ce_criterion(label_preds, labels) + (1 - l) *
                      self.ce_criterion(label_preds, shuffled_labels)).mean()
            #loss_r = self.mse_critetion(reconsts_l, images)
            #loss_e = self.ce_criterion(label_preds_r, labels_r)
            #loss_i = self.mse_critetion(latents_l, latents_r)

            losses = loss_c  #+ loss_vat # + loss_r + loss_e + loss_i

            with apex.amp.scale_loss(losses, self.optimizer) as scaled_loss:
                scaled_loss.backward()

            self.optimizer.step()

            sum_loss.update(losses.item())
            cls_loss.update(loss_c.item())

            info_tensor = torch.tensor([sum_loss(), cls_loss()],
                                       device=self.device)
            torch.distributed.reduce(info_tensor, dst=0)

            if self.is_master:
                info_tensor = info_tensor / torch.distributed.get_world_size()
                pbar.update(1)
                pbar.set_postfix({
                    'sum_loss': info_tensor[0].item(),
                    'cls_loss': info_tensor[1].item()
                })

        self.scheduler.step()

    def _test_epoch(self, test_dataloader):
        with torch.no_grad():
            if self.is_master:
                pbar = tqdm(desc=f'Test, epoch #{self.last_epoch}',
                            total=len(test_dataloader))

            self.encoder.eval()

            loss, acc, quality_metric = AvgMeter(), AvgMeter(), 0
            for images, labels in test_dataloader:
                images = images.to(self.device)
                labels = labels.to(self.device)

                label_preds = self.encoder(images)
                loss_val = self.ce_criterion(label_preds, labels).mean()
                acc_val = (torch.argmax(label_preds,
                                        dim=-1) == labels).float().mean()

                loss.update(loss_val.item())
                acc.update(acc_val.item())

                info_tensor = torch.tensor([loss(), acc()], device=self.device)
                torch.distributed.reduce(info_tensor, dst=0)

                if self.is_master:
                    info_tensor = info_tensor / torch.distributed.get_world_size()
                    quality_metric = info_tensor[1].item()
                    pbar.update(1)
                    pbar.set_postfix({
                        'loss': info_tensor[0].item(),
                        'acc': info_tensor[1].item()
                    })

            return quality_metric

    def _save_checkpoint(self, checkpoint_path):
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(self.encoder.module.state_dict(), checkpoint_path)

    def train(self,
              train_data,
              n_epochs,
              batch_size,
              test_data=None,
              last_checkpoint_path=None,
              best_checkpoint_path=None):

        num_replicas = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        batch_size = batch_size // num_replicas

        train_sampler = DistributedSampler(train_data,
                                           shuffle=True,
                                           num_replicas=num_replicas,
                                           rank=rank)
        train_dataloader = DataLoader(train_data,
                                      batch_size=batch_size,
                                      sampler=train_sampler,
                                      num_workers=self.n_jobs)

        if test_data is not None:
            test_sampler = DistributedSampler(test_data,
                                              shuffle=False,
                                              num_replicas=num_replicas,
                                              rank=rank)
            test_dataloader = DataLoader(test_data,
                                         batch_size=batch_size,
                                         sampler=test_sampler,
                                         num_workers=self.n_jobs)

        best_metric = float('-inf')
        for epoch in range(n_epochs):
            torch.cuda.empty_cache()
            self._train_epoch(train_dataloader)

            if last_checkpoint_path is not None and self.is_master:
                self._save_checkpoint(last_checkpoint_path)

            if test_data is not None:
                torch.cuda.empty_cache()
                metric = self._test_epoch(test_dataloader)

                if best_checkpoint_path is not None and self.is_master:
                    if metric > best_metric:
                        best_metric = metric
                        self._save_checkpoint(best_checkpoint_path)

            self.last_epoch += 1
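
The Trainer above expects torch.distributed to be initialized before construction (it calls torch.distributed.get_rank()) and apex to be installed. A rough per-process launch sketch, with build_models(), train_data and test_data being hypothetical names standing in for the user's own models and datasets:

import os
import torch

def main_worker(rank, world_size):
    # One process per GPU; NCCL backend matches apex.parallel.DistributedDataParallel
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    torch.distributed.init_process_group('nccl', rank=rank, world_size=world_size)

    encoder, decoder = build_models()  # hypothetical factory for the two networks
    trainer = Trainer(encoder, decoder,
                      optimizer_params={'lr': 1e-3, 'weight_decay': 1e-4},
                      amp_params={'opt_level': 'O1'},
                      n_jobs=4,
                      rank=rank)
    trainer.train(train_data,            # hypothetical datasets
                  n_epochs=100,
                  batch_size=256,
                  test_data=test_data,
                  last_checkpoint_path='checkpoints/last.pt',
                  best_checkpoint_path='checkpoints/best.pt')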
Example #5
def create_optimizer(cfg, model, filter_bias_and_bn=True):
    opt_lower = cfg.SOLVER.OPTIMIZER.lower()
    lr = cfg.SOLVER.BASE_LR
    weight_decay = cfg.SOLVER.WEIGHT_DECAY
    momentum = cfg.SOLVER.MOMENTUM
    if 'adamw' in opt_lower or 'radam' in opt_lower:
        # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
        # I don't believe they follow the paper or original Torch7 impl which schedules weight
        # decay based on the ratio of current_lr/initial_lr
        weight_decay /= lr
    if weight_decay and filter_bias_and_bn:
        parameters = []
        for key, value in model.named_parameters():
            if not value.requires_grad:
                continue
            filtered_lr = lr
            filtered_weight_decay = weight_decay
            if "bias" in key:
                filtered_lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
                filtered_weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
            parameters += [{
                "params": [value],
                "lr": filtered_lr,
                "weight_decay": filtered_weight_decay
            }]
        weight_decay = 0.
    else:
        parameters = model.parameters()

    opt_split = opt_lower.split('_')
    opt_name = opt_split[-1]
    if opt_name == 'sgd':
        optimizer = optim.SGD(parameters,
                              lr=lr,
                              momentum=momentum,
                              weight_decay=weight_decay,
                              nesterov=True)
    elif opt_name == 'adam':
        optimizer = optim.Adam(parameters, lr=lr, weight_decay=weight_decay)
    elif opt_name == 'adamw':
        optimizer = AdamW(parameters, lr=lr, weight_decay=weight_decay)
    elif opt_name == 'nadam':
        optimizer = Nadam(parameters, lr=lr, weight_decay=weight_decay)
    elif opt_name == 'radam':
        optimizer = RAdam(parameters, lr=lr, weight_decay=weight_decay)
    elif opt_name == 'adadelta':
        optimizer = optim.Adadelta(parameters,
                                   lr=lr,
                                   weight_decay=weight_decay)
    elif opt_name == 'rmsprop':
        optimizer = optim.RMSprop(parameters,
                                  lr=lr,
                                  alpha=0.9,
                                  momentum=momentum,
                                  weight_decay=weight_decay)
    elif opt_name == 'rmsproptf':
        optimizer = RMSpropTF(parameters,
                              lr=lr,
                              alpha=0.9,
                              momentum=momentum,
                              weight_decay=weight_decay)
    elif opt_name == 'novograd':
        optimizer = NovoGrad(parameters, lr=lr, weight_decay=weight_decay)
    elif opt_name == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, lr=lr, weight_decay=weight_decay)
    else:
        raise ValueError("Invalid optimizer")

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)
    return optimizer
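
A minimal usage sketch for create_optimizer, assuming a yacs-style config object with the SOLVER fields read above (the config library, the values, and the torchvision model are illustrative assumptions, not part of the original):

import torchvision
from yacs.config import CfgNode as CN

cfg = CN()
cfg.SOLVER = CN()
cfg.SOLVER.OPTIMIZER = 'lookahead_radam'   # the 'lookahead_' prefix wraps the base optimizer
cfg.SOLVER.BASE_LR = 1e-3
cfg.SOLVER.WEIGHT_DECAY = 1e-4
cfg.SOLVER.MOMENTUM = 0.9
cfg.SOLVER.BIAS_LR_FACTOR = 2.0
cfg.SOLVER.WEIGHT_DECAY_BIAS = 0.0

model = torchvision.models.resnet18()
optimizer = create_optimizer(cfg, model)   # RAdam (with lr-scaled weight decay) inside Lookahead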
Example #6
def main():
    args = parse_args()
    conf = Config(args.conf)

    data_dir = conf.data_dir
    fold_id = conf.fold_id

    workspace = Workspace(conf.run_id).setup()
    workspace.save_conf(args.conf)
    workspace.log(f'{conf.to_dict()}')

    torch.cuda.set_device(0)

    if conf.use_augmentor:
        if conf.augmentor_type == 'v1':
            augmentor = create_augmentor_v1(
                enable_random_morph=conf.enable_random_morph)
        elif conf.augmentor_type == 'v2':
            augmentor = create_augmentor_v2(
                enable_random_morph=conf.enable_random_morph,
                invert_color=conf.invert_color)
        elif conf.augmentor_type == 'v3':
            if conf.input_size_tuple:
                input_size = tuple(conf.input_size_tuple)
            else:
                input_size = (conf.input_size, conf.input_size) if conf.input_size else \
                             (SOURCE_IMAGE_HEIGHT, SOURCE_IMAGE_WIDTH)
            augmentor = create_augmentor_v3(
                input_size,
                enable_random_morph=conf.enable_random_morph,
                invert_color=conf.invert_color)
        else:
            raise ValueError(conf.augmentor_type)
        workspace.log(f'Use augmentor: {conf.augmentor_type}')
    else:
        augmentor = None

    if not conf.input_size_tuple and conf.input_size == 0:
        train_transformer = create_transformer_v1(augmentor=augmentor)
        val_transformer = create_testing_transformer_v1()
        workspace.log('Input size: default')
    else:
        if conf.input_size_tuple:
            input_size = tuple(conf.input_size_tuple)
        else:
            input_size = (conf.input_size, conf.input_size)
        train_transformer = create_transformer_v1(input_size=input_size,
                                                  augmentor=augmentor)
        val_transformer = create_testing_transformer_v1(input_size=input_size)
        workspace.log(f'Input size: {input_size}')

    train_dataset, val_dataset = bengali_dataset(
        data_dir,
        fold_id=fold_id,
        train_transformer=train_transformer,
        val_transformer=val_transformer,
        invert_color=conf.invert_color,
        n_channel=conf.n_channel,
        use_grapheme_code=conf.use_grapheme_code,
        logger=workspace.logger)
    workspace.log(f'#train={len(train_dataset)}, #val={len(val_dataset)}')
    train_dataset.set_low_freq_groups(n_class=conf.n_class_low_freq)

    if conf.sampler_type == 'pk':
        sampler = PKSampler(train_dataset,
                            n_iter_per_epoch=conf.n_iter_per_epoch,
                            p=conf.batch_p,
                            k=conf.batch_k)
        train_loader = DataLoader(train_dataset,
                                  shuffle=False,
                                  num_workers=8,
                                  pin_memory=True,
                                  batch_sampler=sampler)
        workspace.log(f'{sampler} is enabled')
        workspace.log(f'Real batch_size={sampler.batch_size}')
    elif conf.sampler_type == 'random+append':
        batch_sampler = LowFreqSampleMixinBatchSampler(
            train_dataset,
            conf.batch_size,
            n_low_freq_samples=conf.n_low_freq_samples,
            drop_last=True)
        train_loader = DataLoader(train_dataset,
                                  shuffle=False,
                                  num_workers=8,
                                  pin_memory=True,
                                  batch_sampler=batch_sampler)
        workspace.log(f'{batch_sampler} is enabled')
        workspace.log(f'Real batch_size={batch_sampler.batch_size}')
    elif conf.sampler_type == 'random':
        train_loader = DataLoader(train_dataset,
                                  batch_size=conf.batch_size,
                                  shuffle=True,
                                  num_workers=8,
                                  pin_memory=True,
                                  drop_last=True)
    else:
        raise ValueError(f'Invalid sampler_type: {conf.sampler_type}')

    val_loader = DataLoader(val_dataset,
                            batch_size=conf.batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True)

    workspace.log(f'Create init model: arch={conf.arch}')
    model = create_init_model(conf.arch,
                              pretrained=True,
                              pooling=conf.pooling_type,
                              dim=conf.feat_dim,
                              use_maxblurpool=conf.use_maxblurpool,
                              remove_last_stride=conf.remove_last_stride,
                              n_channel=conf.n_channel)
    if conf.weight_file:
        pretrained_weight = torch.load(conf.weight_file, map_location='cpu')
        result = model.load_state_dict(pretrained_weight)
        workspace.log(f'Pretrained weights were loaded: {conf.weight_file}')
        workspace.log(result)

    model = model.cuda()

    sub_models = []

    criterion_g = get_criterion(conf.loss_type_g,
                                weight=train_dataset.get_class_weights_g(),
                                rate=conf.ohem_rate)
    workspace.log(f'Loss type (g): {conf.loss_type_g}')

    criterion_v = get_criterion(conf.loss_type_v,
                                weights=train_dataset.get_class_weights_v(),
                                rate=conf.ohem_rate)
    workspace.log(f'Loss type (v): {conf.loss_type_v}')

    criterion_c = get_criterion(conf.loss_type_c,
                                weights=train_dataset.get_class_weights_c(),
                                rate=conf.ohem_rate)
    workspace.log(f'Loss type (c): {conf.loss_type_c}')

    if conf.loss_type_feat_g != 'none':
        assert isinstance(
            model, (M.BengaliResNet34V3, M.BengaliResNet34V4,
                    M.BengaliResNet34AGeMV4, M.BengaliSEResNeXt50V4,
                    M.BengaliEfficientNetB0V4, M.BengaliEfficientNetB3V4))
        criterion_feat_g = get_criterion(conf.loss_type_feat_g,
                                         dim=model.multihead.head_g.dim,
                                         n_class=168,
                                         s=conf.af_scale_g)
        workspace.log(f'Loss type (fg): {conf.loss_type_feat_g}')
        if conf.loss_type_feat_g in ('af', ):
            sub_models.append(criterion_feat_g)
            workspace.log('Add criterion_feat_g to sub model')
    else:
        criterion_feat_g = None

    if conf.loss_type_feat_v != 'none':
        assert isinstance(
            model, (M.BengaliResNet34V3, M.BengaliResNet34V4,
                    M.BengaliResNet34AGeMV4, M.BengaliSEResNeXt50V4,
                    M.BengaliEfficientNetB0V4, M.BengaliEfficientNetB3V4))
        criterion_feat_v = get_criterion(conf.loss_type_feat_v,
                                         dim=model.multihead.head_v.dim,
                                         n_class=11,
                                         s=conf.af_scale_v)
        workspace.log(f'Loss type (fv): {conf.loss_type_feat_v}')
        if conf.loss_type_feat_v in ('af', ):
            sub_models.append(criterion_feat_v)
            workspace.log('Add criterion_feat_v to sub model')
    else:
        criterion_feat_v = None

    if conf.loss_type_feat_c != 'none':
        assert isinstance(
            model, (M.BengaliResNet34V3, M.BengaliResNet34V4,
                    M.BengaliResNet34AGeMV4, M.BengaliSEResNeXt50V4,
                    M.BengaliEfficientNetB0V4, M.BengaliEfficientNetB3V4))
        criterion_feat_c = get_criterion(conf.loss_type_feat_c,
                                         dim=model.multihead.head_c.dim,
                                         n_class=7,
                                         s=conf.af_scale_c)
        workspace.log(f'Loss type (fc): {conf.loss_type_feat_c}')
        if conf.loss_type_feat_c in ('af', ):
            sub_models.append(criterion_feat_c)
            workspace.log('Add criterion_feat_c to sub model')
    else:
        criterion_feat_c = None

    if conf.use_grapheme_code:
        workspace.log('Use grapheme code classifier')
        grapheme_classifier = nn.Sequential(nn.BatchNorm1d(168 + 11 + 7),
                                            nn.Linear(168 + 11 + 7, 1295))
        grapheme_classifier = grapheme_classifier.cuda()
        grapheme_classifier.train()
        sub_models.append(grapheme_classifier)
        criterion_grapheme = L.OHEMCrossEntropyLoss().cuda()
    else:
        grapheme_classifier = None
        criterion_grapheme = None

    parameters = [{'params': model.parameters()}] + \
                 [{'params': sub_model.parameters()} for sub_model in sub_models]

    if conf.optimizer_type == 'adam':
        optimizer = torch.optim.Adam(parameters, lr=conf.lr)
    elif conf.optimizer_type == 'sgd':
        optimizer = torch.optim.SGD(parameters,
                                    lr=conf.lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    elif conf.optimizer_type == 'ranger':
        optimizer = Ranger(parameters, lr=conf.lr, weight_decay=1e-4)
    elif conf.optimizer_type == 'radam':
        optimizer = RAdam(parameters, lr=conf.lr, weight_decay=1e-4)
    else:
        raise ValueError(conf.optimizer_type)
    workspace.log(f'Optimizer type: {conf.optimizer_type}')

    if conf.use_apex:
        workspace.log('Apex initialization')
        _models, optimizer = amp.initialize([model] + sub_models,
                                            optimizer,
                                            opt_level=conf.apex_opt_level)
        if len(_models) == 1:
            model = _models[0]
        else:
            model = _models[0]
            criterion_feat_g = _models[1]
            criterion_feat_v = _models[2]
            criterion_feat_c = _models[3]
        workspace.log('Initialized by Apex')
        workspace.log(f'{optimizer.__class__.__name__}')
        for m in _models:
            workspace.log(f'{m.__class__.__name__}')

    if conf.scheduler_type == 'cosanl':
        scheduler = CosineLRWithRestarts(
            optimizer,
            conf.batch_size,
            len(train_dataset),
            restart_period=conf.cosanl_restart_period,
            t_mult=conf.cosanl_t_mult)
        workspace.log(f'restart_period={scheduler.restart_period}')
        workspace.log(f't_mult={scheduler.t_mult}')
    elif conf.scheduler_type == 'rop':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=conf.rop_patience,
            mode='max',
            factor=conf.rop_factor,
            min_lr=1e-6,
            verbose=True)
    else:
        raise ValueError(conf.scheduler_type)

    train(model,
          train_loader,
          val_loader,
          optimizer,
          criterion_g,
          criterion_v,
          criterion_c,
          criterion_feat_g,
          criterion_feat_v,
          criterion_feat_c,
          workspace,
          scheduler=scheduler,
          n_epoch=conf.n_epoch,
          cutmix_prob=conf.cutmix_prob,
          mixup_prob=conf.mixup_prob,
          freeze_bn_epochs=conf.freeze_bn_epochs,
          feat_loss_weight=conf.feat_loss_weight,
          use_apex=conf.use_apex,
          decrease_ohem_rate=conf.decrease_ohem_rate,
          use_grapheme_code=conf.use_grapheme_code,
          grapheme_classifier=grapheme_classifier,
          criterion_grapheme=criterion_grapheme,
          final_ft=conf.final_ft)
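
train() itself is defined elsewhere; as a sketch (not the original code), the two scheduler types constructed above are typically advanced differently at the end of each epoch, since ReduceLROnPlateau(mode='max') consumes the monitored validation score while the cosine scheduler runs on its own schedule. run_one_epoch is a hypothetical helper standing in for one training/validation pass:

for epoch in range(conf.n_epoch):
    val_score = run_one_epoch()       # hypothetical: train + validate, return the tracked metric
    if conf.scheduler_type == 'rop':
        scheduler.step(val_score)     # ReduceLROnPlateau expects the metric it monitors
    else:
        scheduler.step()              # CosineLRWithRestarts advances by its restart period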
Example #7
class IQN(nn.Module):
    def __init__(self, state_space, act_n, quantile_dim, num_quantiles,
                 hidden_dim, num_hidden, optim_params):
        """
        Rainbow Recurrent IQN

        IQN: https://arxiv.org/pdf/1806.06923.pdf
        R2D2: https://openreview.net/pdf?id=r1lyTjAqYX
        R2D3: https://arxiv.org/abs/1909.01387
        """
        nn.Module.__init__(self)

        self.online = Model(state_space, act_n, quantile_dim, num_quantiles,
                            hidden_dim, num_hidden)
        self.target = deepcopy(self.online)

        self.loss_func = nn.SmoothL1Loss(reduction="mean")
        self.optim = RAdam(self.online.parameters(), **optim_params)

    def forward(self, inp):
        return self.online(inp)

    def step(self, state, greedy=False):
        """
        Takes a step into the environment
        """
        return self.online.step(state, greedy)

    def train_batch(self, rollouts, burn_in_length, sequence_length):
        """
        Trains for a batch of rollouts with the given burn in length and
        training sequence length
        """
        self.optim.zero_grad()

        states, actions, rewards, next_states, terminals, hidden_state = rollouts

        # Add burn in here #######
        
        next_q_vals, next_quantile_vals, next_quantiles, next_hidden = self.target(next_states)
        num_quantiles = next_quantile_vals.shape[1]

        # Greedy next action from the expected Q-values, repeated across quantiles
        next_actions = next_q_vals.argmax(-1, keepdim=True)
        next_actions = next_actions.unsqueeze(1).repeat(1, num_quantiles, 1)
        next_values = next_quantile_vals.gather(-1, next_actions).squeeze(-1)

        q_vals, quantile_vals, quantiles = self.online(states)
        action_values = quantile_vals.gather(-1, actions)

        td_error = next_values.unsqueeze(2) - action_values.unsqueeze(1)
        quantile_loss = self.loss_func(next_values.unsqueeze(2),
                                       action_values.unsqueeze(1))

        quantiles = quantiles.unsqueeze(1).repeat(1, num_quantiles, 1)
        penalty = torch.abs(quantiles - (td_error < 0).float().detach())

        loss = penalty * quantile_loss # Divide by huber kappa
        loss = loss.sum(2).mean(1)
        meaned_loss = loss.mean()

        meaned_loss.backward()
        self.optim.step()

        return meaned_loss, loss

    def train(self, num_batches, batch_size, burn_in_length, sequence_length,
              online_replay_buffer=None, supervised_replay_buffer=None,
              supervised_chance=0.25, writer=None):
        """
        Trains R2D3 style with 2 replay buffers
        """
        assert not (online_replay_buffer is None and supervised_replay_buffer is None)

        for batch in range(1, num_batches + 1):
            buff_choice = np.random.rand()
            if(online_replay_buffer is None or buff_choice < supervised_chance):
                replay_buffer = supervised_replay_buffer
            else:
                replay_buffer = online_replay_buffer

            while(not replay_buffer.ready_to_sample(batch_size)):
                pass

            rollouts, idxs, is_weights = replay_buffer.sample(batch_size)

            loss, new_errors = self.train_batch(rollouts, burn_in_length,
                                                sequence_length)
            replay_buffer.update_priorities(new_errors, idxs)

            if(writer is not None):
                if(buff_choice < supervised_chance):
                    writer.add_summary("Supervised Loss", loss, batch)
                else:
                    writer.add_summary("Online Loss", loss, batch)

                writer.add_summary("Loss", loss, batch)

    def update_target(self):
        """
        Updates the target network
        """
        self.target.load_state_dict(self.online.state_dict())
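
For reference, the quantile regression Huber loss from the IQN paper that train_batch approximates can be written as a standalone function. This is a generic sketch of the published formulation, not code from the repository above; kappa is the Huber threshold (1.0 in the paper).

import torch
import torch.nn.functional as F

def quantile_huber_loss(td_error, taus, kappa=1.0):
    """td_error: [batch, n_target_quantiles, n_current_quantiles]
    taus:     [batch, 1, n_current_quantiles], sampled quantile fractions."""
    huber = F.smooth_l1_loss(td_error, torch.zeros_like(td_error),
                             reduction='none', beta=kappa)
    # Asymmetric weight: under/over-estimation is penalised according to tau
    weight = torch.abs(taus - (td_error.detach() < 0).float())
    per_sample = (weight * huber / kappa).mean(dim=1).sum(dim=1)
    return per_sample.mean()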
Example #8
def main():
    cifar_train = CIFAR10('.',
                          train=True,
                          transform=transforms.Compose([
                              transforms.Resize((224, 224)),
                              transforms.ToTensor()
                          ]),
                          download=True)
    cifar_test = CIFAR10('.',
                         train=False,
                         transform=transforms.Compose([
                             transforms.Resize((224, 224)),
                             transforms.ToTensor()
                         ]),
                         download=True)

    dl_train = DataLoader(cifar_train, batch_size=16)
    dl_test = DataLoader(cifar_test, batch_size=16)

    logdir = "./logdir/Adam"
    num_epochs = 10

    loaders = {'train': dl_train, 'valid': dl_test}

    model = resnet34()
    for name, param in model.named_parameters():
        param.requires_grad = True

    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())
    runner = dl.SupervisedRunner()

    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        num_epochs=num_epochs,
        verbose=True,
        logdir=logdir,
        callbacks=[
            logger.TensorboardLogger(),
            AccuracyCallback(num_classes=10)
        ],
    )

    logdir = "./logdir/AdamW"

    model.apply(init_weights)
    optimizer = AdamW(model.parameters())
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        num_epochs=num_epochs,
        verbose=True,
        logdir=logdir,
        callbacks=[
            logger.TensorboardLogger(),
            AccuracyCallback(num_classes=10)
        ],
    )

    logdir = "./logdir/RAdam"

    model.apply(init_weights)
    optimizer = RAdam(model.parameters())
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        num_epochs=num_epochs,
        verbose=True,
        logdir=logdir,
        callbacks=[
            logger.TensorboardLogger(),
            AccuracyCallback(num_classes=10)
        ],
    )
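
main() re-initializes the network between the Adam/AdamW/RAdam runs with model.apply(init_weights), but the helper itself is not part of the snippet; a plausible minimal version (an assumption, not the original) is:

import torch.nn as nn

def init_weights(m):
    # Hypothetical re-initialization helper for the optimizer comparison above
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
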
def train(args,
          persuasive_data_iter,
          tree_data_iter,
          model,
          criterion,
          device,
          multitask=False):
    # initial for training
    model.train()
    # build up optimizer
    if (args.optimizer == "Adam"):
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    elif (args.optimizer == 'AdamW'):
        optimizer = optim.AdamW(model.parameters(),
                                lr=args.lr,
                                weight_decay=args.weight_decay)
    elif (args.optimizer == 'Ranger'):
        optimizer = Ranger(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    elif (args.optimizer == 'Radam'):
        optimizer = RAdam(model.parameters(),
                          lr=args.lr,
                          weight_decay=args.weight_decay)
    elif (args.optimizer == 'SGD'):
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr * 1000,
                              momentum=0.9,
                              weight_decay=args.weight_decay)
    else:
        raise NotImplementedError
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.lr_step,
                                          gamma=args.lr_gamma)
    grad_clip = args.grad_clip
    save_path = args.save_path
    accumulate = args.accumulate
    print_every = 100 * accumulate
    eval_every = 25 * accumulate
    #print_every, eval_every = 2, 2

    total_epoch = args.epoch * len(persuasive_data_iter[0])
    print('total training step:', total_epoch)
    persuasive_datas = iter(persuasive_data_iter[0])
    if ((tree_data_iter[0] is not None) and multitask):
        tree_datas = iter(tree_data_iter[0])

    multi_alpha = (args.ac_type_alpha, args.link_type_alpha)
    direct_alpha = (args.adu_alpha, args.para_alpha)
    alpha = (direct_alpha, multi_alpha)
    tree_count = args.tree_count
    best_acc = [0, 0]

    # start training
    model.zero_grad()
    t = time.time()
    persuasive = [[[], [], []], [[], [], []]]
    tree_preds = {
        'label': collections.defaultdict(list),
        'pred': collections.defaultdict(list),
        'loss': collections.defaultdict(float),
        'count': 0
    }
    for count in range(1, total_epoch + 1):
        try:
            datas = next(persuasive_datas)
        except StopIteration:
            persuasive_datas = iter(persuasive_data_iter[0])
            datas = next(persuasive_datas)

        outputs = []
        for data in datas:
            data = convert(data, device)
            pred = model(**data)
            outputs.append(pred)
        labels = {
            'adu_direct': [datas[0]['author'], datas[1]['author']],
            'para_direct': [datas[0]['para_author'], datas[1]['para_author']],
        }
        # simply compare two value
        loss, outputs, labels = persuasive_cal_score(outputs, labels,
                                                     criterion, direct_alpha)

        for i, (p, l) in enumerate(zip(outputs, labels)):
            persuasive[0][i].append(p)
            persuasive[1][i].append(l)

        loss.backward()

        if (multitask and (count % tree_count == 0)):
            try:
                data, label = next(tree_datas)
            except StopIteration:
                tree_datas = iter(tree_data_iter[0])
                data, label = next(tree_datas)
            data = convert(data, device)
            output = model(**data, multitask=True)
            output = {
                'type': output[0],
                'link': output[1],
                'link_type': output[2]
            }
            label = convert(label, device)
            loss, loss_stat, output = tree_cal_score(output, label, None,
                                                     multi_alpha)
            loss.backward()

            for key, val in label.items():
                tree_preds['label'][key].append(val.detach().cpu())
            for key, val in loss_stat.items():
                tree_preds['loss'][key] += val
            update_pred(tree_preds['pred'], output, data['adu_length'])
            tree_preds['count'] += 1

        if (count % accumulate == 0):
            #utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()
            optimizer.zero_grad()

        if (count % eval_every == 0):
            stat = update(persuasive, criterion, dtype='persuasive')

            nt = time.time()
            print(
                'now:{}, time: {:.4f}s'.format(count, nt - t),
                '\npersuasive: [loss: {:.4f}, diff: {:.4f}, acc: {:.4f}]'.
                format(stat[0][0], stat[0][1], stat[0][2]),
                '\tdirect: [adu_loss: {:.4f}, adu_acc: {:.4f}, para_loss: {:.4f}, para_acc: {:.4f}]'
                .format(stat[1][0], stat[1][1], stat[2][0], stat[2][1]),
                flush=True)

            if (multitask):
                stat = update(tree_preds, dtype='tree')
                print(
                    'acc: [link_mst: {:.4f}, link: {:.4f}, type: {:.4f}, link_type: {:.4f}]'
                    .format(stat['acc']['link_mst'], stat['acc']['link'],
                            stat['acc']['type'], stat['acc']['link_type']))
                print(
                    'f1: type: [premise: {:.4f}, claim: {:.4f}], link_type: [support: {:.4f}, attack: {:.4f}]'
                    .format(stat['type']['premise'], stat['type']['claim'],
                            stat['link_type']['support'],
                            stat['link_type']['attack']))
                print('mrr: {:.4f}'.format(stat['mrr_link']), flush=True)
            t = nt

            persuasive = [[[], [], []], [[], [], []]]
            tree_preds = {
                'label': collections.defaultdict(list),
                'pred': collections.defaultdict(list),
                'loss': collections.defaultdict(float),
                'count': 0
            }
            scheduler.step()

        if (count % print_every == 0):
            dev_acc = test('dev {}'.format(count), persuasive_data_iter[1],
                           tree_data_iter[1], model, criterion, device, alpha)
            test_acc = test('test {}'.format(count), persuasive_data_iter[2],
                            tree_data_iter[2], model, criterion, device, alpha)
            if (dev_acc > best_acc[0]):
                best_acc = [dev_acc, test_acc]
            torch.save(model.state_dict(),
                       save_path + '/check_{}.pt'.format(count))
    print('all finish with acc:', best_acc)
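
Gradient clipping is commented out in the accumulation branch above; if re-enabled, the standard PyTorch call would clip the accumulated gradients just before the optimizer step (a sketch, not the original code):

if count % accumulate == 0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)  # clip accumulated grads
    optimizer.step()
    optimizer.zero_grad()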