Example #1
def main():

    args = parse_option()
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
    # create model and optimizer
    # == dataset config==
    """
    CUDA_VISIBLE_DEVICES=0,1 python train_temporal_dis.py \
     --batch_size 16 --num_workers 8 --nce_k 3569 --softmax --moco
    """

    # args.dataset = 'hmdb51'
    # args.train_list = '../datasets/lists/hmdb51/hmdb51_rgb_train_split_1.txt'
    # args.val_list = '../datasets/lists/hmdb51/hmdb51_rgb_val_split_1.txt'
    """
    CUDA_VISIBLE_DEVICES=1 python train_temporal_dis.py \
     --batch_size 16 --num_workers 8 --nce_k 9536 --softmax --moco
    """
    # args.print_freq = 100
    # args.dataset = 'ucf101'
    # args.train_list = '../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt'
    # args.val_list = '../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt'

    # args.print_freq = 1000
    # args.dataset = 'kinetics'
    # args.train_list = '../datasets/lists/kinetics-400/ssd_kinetics_video_trainlist.txt'
    # args.val_list = '../datasets/lists/kinetics-400/ssd_kinetics_video_vallist.txt'

    args.dropout = 0.5
    args.clips = 1
    args.data_length = 16
    args.stride = 4
    args.spatial_size = 224
    args.root = ""
    args.mode = 'rgb'
    args.eval_indict = 'loss'
    args.pt_loss = 'TemporalDis'
    args.workers = 4
    # args.arch = 'i3d' # 'r2p1d'
    num_class, data_length, image_tmpl = data_config(args)
    train_transforms, test_transforms, eval_transforms = augmentation_config(
        args)
    train_loader, val_loader, eval_loader, train_samples, val_samples, eval_samples = data_loader_init(
        args, data_length, image_tmpl, train_transforms, test_transforms,
        eval_transforms)

    n_data = len(train_loader)
    if args.arch == 'i3d':
        model = I3D(num_classes=101,
                    modality=args.mode,
                    dropout_prob=args.dropout,
                    with_classifier=False)
        model_ema = I3D(num_classes=101,
                        modality=args.mode,
                        dropout_prob=args.dropout,
                        with_classifier=False)
    elif args.arch == 'r2p1d':
        model = R2Plus1DNet((1, 1, 1, 1),
                            num_classes=num_class,
                            with_classifier=False)
        model_ema = R2Plus1DNet((1, 1, 1, 1),
                                num_classes=num_class,
                                with_classifier=False)
    elif args.arch == 'r3d':
        from model.r3d import resnet18
        model = resnet18(num_classes=num_class, with_classifier=False)
        model_ema = resnet18(num_classes=num_class, with_classifier=False)
    else:
        raise NotImplementedError("Architecture '{}' is not implemented.".format(args.arch))
    model = torch.nn.DataParallel(model)
    model_ema = torch.nn.DataParallel(model_ema)
    # random initialization
    model.apply(weights_init)
    model_ema.apply(weights_init)
    # copy weights from `model' to `model_ema'
    moment_update(model, model_ema, 0)
    contrast = MemoryMoCo(128, n_data, args.nce_k, args.nce_t,
                          args.softmax).cuda(args.gpu)
    # contrast2 = MemoryMoCo(128, n_data, args.nce_k, args.nce_t, args.softmax).cuda(args.gpu)
    criterion = NCESoftmaxLoss() if args.softmax else NCECriterion(n_data)
    criterion = criterion.cuda(args.gpu)
    cls_criterion = nn.CrossEntropyLoss().cuda()

    model = model.cuda()
    if args.moco:
        model_ema = model_ema.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    if args.amp:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level)
        if args.moco:
            optimizer_ema = torch.optim.SGD(model_ema.parameters(),
                                            lr=0,
                                            momentum=0,
                                            weight_decay=0)
            model_ema, optimizer_ema = amp.initialize(model_ema,
                                                      optimizer_ema,
                                                      opt_level=args.opt_level)

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            contrast.load_state_dict(checkpoint['contrast'])
            if args.moco:
                model_ema.load_state_dict(checkpoint['model_ema'])

            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])

            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)
    logger2 = tb_logger.Logger(logdir=args.tb_folder2, flush_secs=2)

    #==================================== our data augmentation method=================================
    pos_aug = GenPositive()
    neg_aug = GenNegative()

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        loss, prob = train_moco(epoch, train_loader, model, model_ema,
                                contrast, criterion, optimizer, args, pos_aug,
                                neg_aug)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
        saving(logger, loss, epoch, optimizer, args, model, contrast, prob,
               model_ema, 'TemporalDis')
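
The helper `moment_update(model, model_ema, 0)` above is only shown being called; below is a minimal sketch of what such a MoCo-style momentum update presumably does (with m=0 it simply copies the online weights into the EMA encoder). The body is an assumption, not the repository's actual code.

import torch


def moment_update(model, model_ema, m):
    """Momentum update: model_ema = m * model_ema + (1 - m) * model.

    Called with m=0 this copies the online weights into the EMA (key)
    encoder; during training m is usually close to 1 (e.g. 0.999).
    """
    with torch.no_grad():
        for p, p_ema in zip(model.parameters(), model_ema.parameters()):
            p_ema.data.mul_(m).add_(p.detach().data, alpha=1 - m)
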
Example #2
netD, optimizerD = amp.initialize(netD,
                                  optimizerD,
                                  opt_level="O2",
                                  keep_batchnorm_fp32=True,
                                  loss_scale="dynamic")

netG, optimizerG = amp.initialize(netG,
                                  optimizerG,
                                  opt_level="O2",
                                  keep_batchnorm_fp32=True,
                                  loss_scale="dynamic")

if have_checkpoint:
    optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
    optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
    amp.load_state_dict(checkpoint['amp'])

# Training Loop

# Lists to keep track of progress
G_losses = []
D_losses = []
iters = 0

final_epoch = num_epochs + completed_epochs

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    start_time = datetime.now()
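
The resume branch above expects the checkpoint to carry both optimizer states and the amp state. A sketch of a matching save call is below; only the `optimizerD_state_dict`, `optimizerG_state_dict`, and `amp` keys come from the example, the remaining keys and the `checkpoint_path` variable are assumptions.

# Hypothetical save counterpart for the resume logic above.
torch.save(
    {
        'epoch': epoch,                                    # assumed key
        'netD_state_dict': netD.state_dict(),              # assumed key
        'netG_state_dict': netG.state_dict(),              # assumed key
        'optimizerD_state_dict': optimizerD.state_dict(),
        'optimizerG_state_dict': optimizerG.state_dict(),
        'amp': amp.state_dict(),
    },
    checkpoint_path)                                       # hypothetical path variable
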
Example #3
    def load_listener(self):
        if self.listener_checkpointer.has_checkpoint():
            extra_listener_checkpoint_data = self.listener_checkpointer.load()
            amp.load_state_dict(extra_listener_checkpoint_data['amp'])
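
The loader above assumes the checkpoint extras include an 'amp' entry. A hedged sketch of the corresponding save side follows; the method name, the `use_amp` flag, and the checkpointer's `save(name, **extra)` signature are hypothetical, only the 'amp' key is taken from the example.

    def save_listener(self, name, **extra):
        # Hypothetical counterpart: stash the apex amp state next to the
        # model/optimizer states so load_listener() can restore it later.
        if self.use_amp:                                   # hypothetical flag
            extra['amp'] = amp.state_dict()
        self.listener_checkpointer.save(name, **extra)     # assumed checkpointer API
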
Example #4
    def _restore_checkpoint(self) -> int:
        """
        Restores the model and training state from the last saved checkpoint.
        This includes an epoch count and optimizer state, which is serialized separately
        from model parameters. This function should only be used to continue training -
        if you wish to load a model for inference/load parts of a model into a new
        computation graph, you should use the native Pytorch functions:
        `model.load_state_dict(torch.load("/path/to/model/weights.th"))`

        If `self._serialization_dir` does not exist or does not contain any checkpointed weights,
        this function will do nothing and return 0.

        # Returns

        epoch: int
            The epoch at which to resume training, which should be one after the epoch
            in the saved training state.
        """
        model_state, training_state = self._checkpointer.restore_checkpoint()

        if not training_state:
            # No checkpoint to restore, start at 0
            return 0

        # The apex docs recommend calling amp.initialize before calling load_state_dict.
        if self._opt_level is not None and "amp" in training_state:
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, opt_level=self._opt_level)
            amp.load_state_dict(training_state["amp"])
        self.model.load_state_dict(model_state)
        self.optimizer.load_state_dict(training_state["optimizer"])
        if (self._learning_rate_scheduler is not None
                and "learning_rate_scheduler" in training_state):
            self._learning_rate_scheduler.load_state_dict(
                training_state["learning_rate_scheduler"])
        if self._momentum_scheduler is not None and "momentum_scheduler" in training_state:
            self._momentum_scheduler.load_state_dict(
                training_state["momentum_scheduler"])
        training_util.move_optimizer_to_cuda(self.optimizer)

        # Currently the `training_state` contains a serialized `MetricTracker`.
        if "metric_tracker" in training_state:
            self._metric_tracker.load_state_dict(
                training_state["metric_tracker"])
        # It used to be the case that we tracked `val_metric_per_epoch`.
        elif "val_metric_per_epoch" in training_state:
            self._metric_tracker.clear()
            self._metric_tracker.add_metrics(
                training_state["val_metric_per_epoch"])
        # And before that we didn't track anything.
        else:
            self._metric_tracker.clear()

        if isinstance(training_state["epoch"], int):
            epoch_to_return = training_state["epoch"] + 1
        else:
            epoch_to_return = int(training_state["epoch"].split(".")[0]) + 1

        # For older checkpoints with batch_num_total missing, default to old behavior where
        # it is unchanged.
        batch_num_total = training_state.get("batch_num_total")
        if batch_num_total is not None:
            self._batch_num_total = batch_num_total

        return epoch_to_return
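
For reference, the keys read by this restore path imply a `training_state` shaped roughly like the sketch below when it is written out on the save side. Only the key names are taken from the code above; the surrounding save plumbing is an assumption.

# Hypothetical shape of `training_state`, inferred from the keys read above.
training_state = {
    "epoch": epoch,
    "optimizer": self.optimizer.state_dict(),
    "metric_tracker": self._metric_tracker.state_dict(),
    "batch_num_total": self._batch_num_total,
}
if self._learning_rate_scheduler is not None:
    training_state["learning_rate_scheduler"] = self._learning_rate_scheduler.state_dict()
if self._momentum_scheduler is not None:
    training_state["momentum_scheduler"] = self._momentum_scheduler.state_dict()
if self._opt_level is not None:
    training_state["amp"] = amp.state_dict()
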
Example #5
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True)
    parser.add_argument(
        '--dataset',
        default='/home/nax/Downloads/shopee-product-matching/train.csv')
    parser.add_argument(
        '--train-label-split',
        default='/home/nax/Downloads/shopee-product-matching/train_labels.csv')
    parser.add_argument('--config', default='configs/baseline.py')
    parser.add_argument('--apex', action='store_true')
    parser.add_argument('--embedding-size', type=int)
    parser.add_argument('--batch-size', type=int)
    parser.add_argument('--image-size', type=int)
    parser.add_argument('--use_cuda', action='store_true')

    args = parser.parse_args()

    config = util.load_config(args.config)

    util.update_args(args, config)

    if args.apex:
        from apex import amp

    train_labels = np.loadtxt(args.train_label_split, dtype=np.int64)
    val_labels = data.get_val_labels(args.dataset, set(train_labels))
    val_labels = list(val_labels)
    #val_dataset = data.DMLDataset(args.dataset,
    #        image_size=args.image_size,
    #        is_training=False,
    #        onehot_labels=False,
    #        subset_labels=val_labels)
    val_dataset = data.DMLDataset(args.dataset,
                                  image_size=args.image_size,
                                  is_training=False,
                                  onehot_labels=False)
    val_loader = data_util.DataLoader(val_dataset,
                                      batch_size=args.batch_size,
                                      num_workers=2,
                                      collate_fn=val_dataset.collate_fn)
    backbone, embeddings, model, states = model_loader.load_model(
        config, args, args.model)

    model.eval()

    if not args.apex:
        model = torch.nn.DataParallel(model)
    model = model.cuda()

    if args.apex:
        model = amp.initialize(model, opt_level='O1')
        model = torch.nn.DataParallel(model)

    if args.apex:
        amp.load_state_dict(states['amp'])

    model.eval()

    print(
        f'Val accuracy: {eval_utils.evaluate(model, val_loader, use_cuda=args.use_cuda)}'
    )
    print(
        f'F1: {eval_utils.f1_evaluate(model, val_loader, use_cuda=args.use_cuda)}'
    )
    print(
        f'F1: {eval_utils.f1_evaluate(model, val_loader, threshold=1.070329, use_cuda=args.use_cuda)}'
    )
Example #6
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on 1 GPU.')

    torch.manual_seed(args.seed + args.rank)

    # create model
    config = get_efficientdet_config(args.model)
    config.redundant_bias = args.redundant_bias  # redundant conv + BN bias layers (True to match official models)
    model = EfficientDet(config)
    model = DetBenchTrain(model, config)

    # FIXME create model factory, pretrained zoo
    # model = create_model(
    #     args.model,
    #     pretrained=args.pretrained,
    #     num_classes=args.num_classes,
    #     drop_rate=args.drop,
    #     drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
    #     drop_path_rate=args.drop_path,
    #     drop_block_rate=args.drop_block,
    #     global_pool=args.gp,
    #     bn_tf=args.bn_tf,
    #     bn_momentum=args.bn_momentum,
    #     bn_eps=args.bn_eps,
    #     checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    model.cuda()
    optimizer = create_optimizer(args, model)
    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(_unwrap_bench(model),
                                                       args.resume)
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    del resume_state

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '')
        #resume=args.resume)  # FIXME bit of a mess with bench
        if args.resume:
            load_checkpoint(_unwrap_bench(model_ema),
                            args.resume,
                            use_ema=True)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
                    )
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model,
                        device_ids=[args.local_rank
                                    ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    train_anno_set = 'train2017'
    train_annotation_path = os.path.join(args.data, 'annotations',
                                         f'instances_{train_anno_set}.json')
    train_image_dir = train_anno_set
    dataset_train = CocoDetection(os.path.join(args.data, train_image_dir),
                                  train_annotation_path)

    # FIXME cutmix/mixup worth investigating?
    # collate_fn = None
    # if args.prefetcher and args.mixup > 0:
    #     collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

    loader_train = create_loader(
        dataset_train,
        input_size=config.image_size,
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        #re_prob=args.reprob,  # FIXME add back various augmentations
        #re_mode=args.remode,
        #re_count=args.recount,
        #re_split=args.resplit,
        #color_jitter=args.color_jitter,
        #auto_augment=args.aa,
        interpolation=args.train_interpolation,
        #mean=data_config['mean'],
        #std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        #collate_fn=collate_fn,
        pin_mem=args.pin_mem,
    )

    train_anno_set = 'val2017'
    train_annotation_path = os.path.join(args.data, 'annotations',
                                         f'instances_{train_anno_set}.json')
    train_image_dir = train_anno_set
    dataset_eval = CocoDetection(os.path.join(args.data, train_image_dir),
                                 train_annotation_path)

    loader_eval = create_loader(
        dataset_eval,
        input_size=config.image_size,
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=args.interpolation,
        #mean=data_config['mean'],
        #std=data_config['std'],
        num_workers=args.workers,
        #distributed=args.distributed,
        pin_mem=args.pin_mem,
    )

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join(
            [datetime.now().strftime("%Y%m%d-%H%M%S"), args.model])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, args)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')

                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            args,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    _unwrap_bench(model),
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=_unwrap_bench(model_ema),
                    metric=save_metric,
                    use_amp=use_amp)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
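
The comment above stresses creating the EMA model after `.cuda()` and AMP initialization but before the SyncBN/DDP wrappers. Below is a minimal sketch of such an EMA helper for illustration; it is not the actual `ModelEma` implementation used here, which also handles checkpoint resume.

import copy

import torch


class SimpleModelEma:
    """Keeps an exponential moving average copy of a model's weights (sketch only)."""

    def __init__(self, model, decay=0.9998, device=''):
        self.ema = copy.deepcopy(model).eval()
        self.decay = decay
        if device:
            self.ema.to(device=device)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Blend current weights into the EMA copy; non-float buffers are copied as-is.
        with torch.no_grad():
            msd = model.state_dict()
            for k, v in self.ema.state_dict().items():
                src = msd[k].to(v.device)
                if v.dtype.is_floating_point:
                    v.copy_(v * self.decay + (1. - self.decay) * src)
                else:
                    v.copy_(src)
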
Example #7
def train_and_fit(args):

    if args.fp16:
        from apex import amp
    else:
        amp = None

    cuda = torch.cuda.is_available()

    train_loader = load_dataloaders(args)
    train_len = len(train_loader)
    logger.info("Loaded %d pre-training samples." % train_len)

    if args.model_no == 0:
        from .model.BERT.modeling_bert import BertModel as Model
        model = args.model_size  #'bert-base-uncased'
        lower_case = True
        model_name = 'BERT'
        net = Model.from_pretrained(model, force_download=False, \
                                model_size=args.model_size)
    elif args.model_no == 1:
        from .model.ALBERT.modeling_albert import AlbertModel as Model
        model = args.model_size  #'albert-base-v2'
        lower_case = False
        model_name = 'ALBERT'
        net = Model.from_pretrained(model, force_download=False, \
                                model_size=args.model_size)
    elif args.model_no == 2:  # BioBert
        from .model.BERT.modeling_bert import BertModel, BertConfig
        model = 'bert-base-uncased'
        lower_case = False
        model_name = 'BioBERT'
        config = BertConfig.from_pretrained(
            os.getcwd() +
            '/additional_models/biobert_v1.1_pubmed/bert_config.json')
        net = BertModel.from_pretrained(pretrained_model_name_or_path=os.getcwd() + '/additional_models/biobert_v1.1_pubmed/biobert_v1.1_pubmed.bin',
                                          config=config,
                                          force_download=False, \
                                          model_size='bert-base-uncased')

    tokenizer = load_pickle("%s_tokenizer.pkl" % model_name)
    net.resize_token_embeddings(len(tokenizer))
    e1_id = tokenizer.convert_tokens_to_ids('[E1]')
    e2_id = tokenizer.convert_tokens_to_ids('[E2]')
    assert e1_id != e2_id != 1

    if cuda:
        net.cuda()

    if args.freeze == 1:
        logger.info("FREEZING MOST HIDDEN LAYERS...")
        if args.model_no == 0:
            unfrozen_layers = ["classifier", "pooler", "encoder.layer.11", "encoder.layer.10",\
                               "encoder.layer.9", "blanks_linear", "lm_linear", "cls"]
        elif args.model_no == 1:
            unfrozen_layers = ["classifier", "pooler", "embeddings", "attention",\
                               "blanks_linear", "lm_linear", "cls",\
                               "albert_layer_groups.0.albert_layers.0.ffn"]

        for name, param in net.named_parameters():
            if not any([layer in name for layer in unfrozen_layers]):
                print("[FROZE]: %s" % name)
                param.requires_grad = False
            else:
                print("[FREE]: %s" % name)
                param.requires_grad = True

    criterion = Two_Headed_Loss(lm_ignore_idx=tokenizer.pad_token_id,
                                use_logits=True,
                                normalize=False)
    optimizer = optim.Adam([{"params": net.parameters(), "lr": args.lr}])

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2,4,6,8,12,15,18,20,22,\
                                                                      24,26,30], gamma=0.8)

    start_epoch, best_pred, amp_checkpoint = load_state(net,
                                                        optimizer,
                                                        scheduler,
                                                        args,
                                                        load_best=False)

    if (args.fp16) and (amp is not None):
        logger.info("Using fp16...")
        net, optimizer = amp.initialize(net, optimizer, opt_level='O2')
        if amp_checkpoint is not None:
            amp.load_state_dict(amp_checkpoint)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2,4,6,8,12,15,18,20,22,\
                                                                          24,26,30], gamma=0.8)

    losses_per_epoch, accuracy_per_epoch = load_results(args.model_no)

    logger.info("Starting training process...")
    pad_id = tokenizer.pad_token_id
    mask_id = tokenizer.mask_token_id
    update_size = len(train_loader) // 10
    for epoch in range(start_epoch, args.num_epochs):
        start_time = time.time()
        net.train()
        total_loss = 0.0
        losses_per_batch = []
        total_acc = 0.0
        lm_accuracy_per_batch = []
        for i, data in enumerate(train_loader, 0):
            x, masked_for_pred, e1_e2_start, _, blank_labels, _, _, _, _, _ = data
            masked_for_pred1 = masked_for_pred
            masked_for_pred = masked_for_pred[(masked_for_pred != pad_id)]
            if masked_for_pred.shape[0] == 0:
                print('Empty dataset, skipping...')
                continue
            attention_mask = (x != pad_id).float()
            token_type_ids = torch.zeros((x.shape[0], x.shape[1])).long()

            if cuda:
                x = x.cuda()
                masked_for_pred = masked_for_pred.cuda()
                attention_mask = attention_mask.cuda()
                token_type_ids = token_type_ids.cuda()

            blanks_logits, lm_logits = net(x, token_type_ids=token_type_ids, attention_mask=attention_mask, Q=None,\
                          e1_e2_start=e1_e2_start)
            lm_logits = lm_logits[(x == mask_id)]

            #return lm_logits, blanks_logits, x, e1_e2_start, masked_for_pred, masked_for_pred1, blank_labels, tokenizer # for debugging now
            if (i % update_size) == (update_size - 1):
                verbose = True
            else:
                verbose = False

            loss = criterion(lm_logits,
                             blanks_logits,
                             masked_for_pred,
                             blank_labels,
                             verbose=verbose)
            loss = loss / args.gradient_acc_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if args.fp16:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), args.max_norm)
            else:
                grad_norm = clip_grad_norm_(net.parameters(), args.max_norm)

            if (i % args.gradient_acc_steps) == 0:
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item()
            total_acc += evaluate_(lm_logits, blanks_logits, masked_for_pred, blank_labels, \
                                   tokenizer, print_=False)[0]

            if (i % update_size) == (update_size - 1):
                losses_per_batch.append(args.gradient_acc_steps * total_loss /
                                        update_size)
                lm_accuracy_per_batch.append(total_acc / update_size)
                print(
                    '[Epoch: %d, %5d/ %d points] total loss, lm accuracy per batch: %.3f, %.3f'
                    % (epoch + 1, (i + 1), train_len, losses_per_batch[-1],
                       lm_accuracy_per_batch[-1]))
                total_loss = 0.0
                total_acc = 0.0
                logger.info("Last batch samples (pos, neg): %d, %d" % ((blank_labels.squeeze() == 1).sum().item(),\
                                                                    (blank_labels.squeeze() == 0).sum().item()))

        scheduler.step()
        losses_per_epoch.append(sum(losses_per_batch) / len(losses_per_batch))
        accuracy_per_epoch.append(
            sum(lm_accuracy_per_batch) / len(lm_accuracy_per_batch))
        print("Epoch finished, took %.2f seconds." %
              (time.time() - start_time))
        print("Losses at Epoch %d: %.7f" % (epoch + 1, losses_per_epoch[-1]))
        print("Accuracy at Epoch %d: %.7f" %
              (epoch + 1, accuracy_per_epoch[-1]))

        if accuracy_per_epoch[-1] > best_pred:
            best_pred = accuracy_per_epoch[-1]
            torch.save({
                    'epoch': epoch + 1,\
                    'state_dict': net.state_dict(),\
                    'best_acc': accuracy_per_epoch[-1],\
                    'optimizer' : optimizer.state_dict(),\
                    'scheduler' : scheduler.state_dict(),\
                    'amp': amp.state_dict() if amp is not None else amp
                }, os.path.join(os.getcwd() + "/src/data/" , "test_model_best_%d.pth.tar" % args.model_no))

        if (epoch % 1) == 0:
            save_as_pickle("test_losses_per_epoch_%d.pkl" % args.model_no,
                           losses_per_epoch)
            save_as_pickle("test_accuracy_per_epoch_%d.pkl" % args.model_no,
                           accuracy_per_epoch)
            torch.save({
                    'epoch': epoch + 1,\
                    'state_dict': net.state_dict(),\
                    'best_acc': accuracy_per_epoch[-1],\
                    'optimizer' : optimizer.state_dict(),\
                    'scheduler' : scheduler.state_dict(),\
                    'amp': amp.state_dict() if amp is not None else amp
                }, os.path.join(os.getcwd() + "/src/data/" , "test_checkpoint_%d.pth.tar" % args.model_no))

    logger.info("Finished Training!")
    fig = plt.figure(figsize=(20, 20))
    ax = fig.add_subplot(111)
    ax.scatter([e for e in range(len(losses_per_epoch))], losses_per_epoch)
    ax.tick_params(axis="both", length=2, width=1, labelsize=14)
    ax.set_xlabel("Epoch", fontsize=22)
    ax.set_ylabel("Training Loss per batch", fontsize=22)
    ax.set_title("Training Loss vs Epoch", fontsize=32)
    plt.savefig(
        os.path.join(os.getcwd() + "/src/data/",
                     "loss_vs_epoch_%d.png" % args.model_no))

    fig2 = plt.figure(figsize=(20, 20))
    ax2 = fig2.add_subplot(111)
    ax2.scatter([e for e in range(len(accuracy_per_epoch))],
                accuracy_per_epoch)
    ax2.tick_params(axis="both", length=2, width=1, labelsize=14)
    ax2.set_xlabel("Epoch", fontsize=22)
    ax2.set_ylabel("Test Masked LM Accuracy", fontsize=22)
    ax2.set_title("Test Masked LM Accuracy vs Epoch", fontsize=32)
    plt.savefig(
        os.path.join(os.getcwd() + "/src/data/",
                     "accuracy_vs_epoch_%d.png" % args.model_no))

    return net
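
`load_state(...)` above returns `start_epoch, best_pred, amp_checkpoint`. The sketch below shows how such a loader could read the checkpoint dict saved later in this function; the key names and file names are taken from the `torch.save` calls above, but the function body itself is an assumption.

import os

import torch


def load_state(net, optimizer, scheduler, args, load_best=False):
    """Hypothetical sketch matching the checkpoint layout saved in train_and_fit."""
    base = os.path.join(os.getcwd(), "src", "data")
    name = "test_model_best_%d.pth.tar" if load_best else "test_checkpoint_%d.pth.tar"
    path = os.path.join(base, name % args.model_no)
    if not os.path.isfile(path):
        return 0, 0.0, None
    checkpoint = torch.load(path, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    return checkpoint['epoch'], checkpoint['best_acc'], checkpoint['amp']
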
Example #8
    def restore_training_state(self, checkpoint):
        """
        Restore trainer state.
        Model will get its chance to update.
        :param checkpoint:
        :return:
        """
        # validation
        if 'optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint:
            raise KeyError(
                'Trying to restore training state but checkpoint contains only the model.'
                ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
            )

        if any([key in checkpoint for key in DEPRECATED_CHECKPOINT_KEYS]):
            raise ValueError(
                "The checkpoint you're attempting to load follows an"
                " outdated schema. You can upgrade to the current schema by running"
                " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
                " where `model.ckpt` is your checkpoint file.")

        # restore amp scaling
        if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
            self.trainer.scaler.load_state_dict(
                checkpoint['native_amp_scaling_state'])
        elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
            amp.load_state_dict(checkpoint['amp_scaling_state'])

        # restore callback states
        self.trainer.on_load_checkpoint(checkpoint)

        self.trainer.global_step = checkpoint['global_step']
        self.trainer.current_epoch = checkpoint['epoch']

        # crash if max_epochs is lower than the current epoch from the checkpoint
        if self.trainer.current_epoch > self.trainer.max_epochs:
            m = f"""
            you restored a checkpoint with current_epoch={self.trainer.current_epoch}
            but the Trainer(max_epochs={self.trainer.max_epochs})
            """
            raise MisconfigurationException(m)

        # Division deals with global step stepping once per accumulated batch
        # Inequality deals with different global step for odd vs even num_training_batches
        n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches
        expected_steps = self.trainer.num_training_batches / n_accum
        if self.trainer.num_training_batches != 0 and self.trainer.global_step % expected_steps > 1:
            rank_zero_warn(
                "You're resuming from a checkpoint that ended mid-epoch. "
                "This can cause unreliable results if further training is done, "
                "consider using an end of epoch checkpoint. ")

        # restore the optimizers
        optimizer_states = checkpoint['optimizer_states']
        for optimizer, opt_state in zip(self.trainer.optimizers,
                                        optimizer_states):
            optimizer.load_state_dict(opt_state)

            # move optimizer to GPU 1 weight at a time
            # avoids OOM
            if self.trainer.root_gpu is not None:
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda(self.trainer.root_gpu)

        # restore the lr schedulers
        lr_schedulers = checkpoint['lr_schedulers']
        for scheduler, lrs_state in zip(self.trainer.lr_schedulers,
                                        lr_schedulers):
            scheduler['scheduler'].load_state_dict(lrs_state)
Example #9
def main(argv=None):
    # Load Config
    with open(FLAGS.config_path, 'r') as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)

    print('Configs overview:')
    print(json.dumps(config, indent=2))

    # define SummaryWriter for tensorboard
    log_dir = config['log']['log_dir'] + ';' + str(DATETIME).replace(' ', ';')
    writer = SummaryWriter(log_dir=os.path.join(log_dir, 'tensorboard'))
    writer.add_text('config', json.dumps(config, indent=2).replace(' ', ' ').replace('\n', '  \n'))

    # define dataloader
    dataset_options = config['dataset']
    dataloader_options = config['dataloader']
    if dataset_options['mode'] == 'dali':
        train_data_loader = DALIDataLoader(csv_path = dataset_options['train_csv_path'],
                                           data_path = dataset_options['train_image_path'],
                                           batch_size=dataloader_options['batch_size'], valid=False,
                                           nfold=dataset_options['validation']['nfold']) 
        valid_data_loader = DALIDataLoader(csv_path = dataset_options['train_csv_path'],
                                           data_path = dataset_options['train_image_path'],
                                           batch_size=dataloader_options['batch_size'], valid=True,
                                           nfold=dataset_options['validation']['nfold']) 
    else:
        train_data_loader, valid_data_loader = load_data(data_path=dataset_options['train_image_path'],
                                                               csv_path=dataset_options['train_csv_path'],
                                                               batch_size=dataloader_options['batch_size'],
                                                               train_trans_list=dataloader_options['data_augment']['train_trans'],
                                                               valid_trans_list=dataloader_options['data_augment']['valid_trans'],
                                                               trans_mode=dataloader_options['data_augment']['mode'],
                                                               trans_lib=dataloader_options['data_augment']['lib'],
                                                               valid_mode=dataset_options['validation']['mode'],
                                                               nfold=dataset_options['validation']['nfold'],
                                                               mode=dataset_options['mode'],
                                                               lmdb_path=dataset_options['lmdb_path'])
    
    # define model
    model_options = config['model']
    model = get_model_from_name(model_name=model_options['model_name'],
                                image_size=model_options['image_size'],
                                num_classes=model_options['num_classes'],
                                pretrained=model_options['pretrained'])
    # for tensorboard and torchsummary
    image_size = model_options['image_size']
    writer.add_graph(model, torch.zeros(image_size).unsqueeze(dim=0))
    model = model.to(DEVICE)
    if 'torchsummary' in sys.modules:
        tsummary(model, tuple(image_size))


    # define loss
    loss_options = config['loss']
    loss_params = loss_options['loss_params']
    loss_fn = get_loss_from_name(loss_name=loss_options['loss_name'],
                                 loss_params=loss_params,
                                 mode=loss_options['mode'],
                                 lib=loss_options['lib'])

    # define metric
    metric_dict = get_metrics(config['metric'])

    train_options = config['train']
    optimizer_list = []
    scheduler_list = []
    for phase in range(len(train_options)):
        # define optimizer
        opt_options = train_options[phase]['optimizer']
        opt_params = opt_options['opt_params']
        optimizer_list.append(get_optimizer_from_name(opt_name=opt_options['opt_name'], 
                                                      model=model,
                                                      target=opt_options['target'], 
                                                      opt_params=opt_params,
                                                      mode=opt_options['mode'],
                                                      lib=opt_options['lib']))
        # define scheduler
        scheduler_options = train_options[phase]['scheduler']
        scheduler_params = scheduler_options['scheduler_params']
        scheduler_list.append(get_scheduler_from_name(scheduler_name=scheduler_options['scheduler_name'], optimizer=optimizer_list[phase],
                                                      scheduler_params=scheduler_params,
                                                      mode=scheduler_options['mode'],
                                                      lib=scheduler_options['lib']))
    # Mixed Precision
    amp_options = config['amp']
    use_amp = ('apex' in sys.modules) and amp_options['use_amp']
    if use_amp:
        model, optimizer_list = amp.initialize(model, optimizer_list, opt_level=amp_options['opt_level'], num_losses=1)
    if model_options['dataparallel']:
        model = nn.DataParallel(model)

    # for Restart
    checkpoint = torch.load(model_options['checkpoint'])
    if 'model' in model_options['restart']:
        model.load_state_dict(checkpoint['model'])
    if 'optimizer' in model_options['restart']:
        optimizer_list[0].load_state_dict(checkpoint['optimizer'])
    if 'amp' in model_options['restart']:
        amp.load_state_dict(checkpoint['amp'])

    # for save params
    save_options = config['log']
    save_dir = os.path.join(log_dir, 'checkpoints')
    os.mkdir(save_dir)
    save_keys = ['model']
    save_target = [model]
    if use_amp:
        save_keys.append('amp')
        save_target.append(amp)
    save_keys.append('optimizer')

    global_epoch = 0
    best_val_score = 0
    for phase in range(len(train_options)):
        print('Start Train phase: {}'.format(phase)) 
        
        for e in range(global_epoch + 1, global_epoch + train_options[phase]['epoch'] + 1):
            # Training
            train_scores = train_loop(model, train_data_loader, loss_fn, optimizer_list[phase], metric_dict=metric_dict, use_amp=use_amp)
            print_save_scores(train_scores, e, 'Train', writer, display=False)
    
            # Validation
            valid_scores = valid_loop(model, valid_data_loader, loss_fn, metric_dict=metric_dict)
            print_save_scores(valid_scores, e, 'Valid', writer, display=False)
    
            # Update Scheduler
            scheduler_list[phase].step()
            lrs = {}
            for idx, lr in enumerate(scheduler_list[phase].get_lr()):
                lrs['group_{}'.format(idx)] = lr
            writer.add_scalars('LearningRate', lrs, e)
    
            # Summary
            print_save_scores_summary(train_scores, valid_scores, e, writer, display=True)

            # Save Params
            if (best_val_score < valid_scores[save_options['save_best_target']]
                    and save_options['save_best_val']
                    and save_options['save_skip_epoch'] < e):
                best_val_score = valid_scores[save_options['save_best_target']]
                save_params(keys=save_keys,
                            targets=save_target + [optimizer_list[phase]],
                            save_path=os.path.join(
                                save_dir, save_options['save_name'] + '-' +
                                str(e) + '-' + str(best_val_score) + '.pth'))
            elif e % save_options['save_interval'] == 0 and save_options['save_skip_epoch'] < e:
                save_params(keys=save_keys,
                            targets=save_target + [optimizer_list[phase]],
                            save_path=os.path.join(
                                save_dir,
                                save_options['save_name'] + '-' + str(e) + '.pth'))

        global_epoch = global_epoch + train_options[phase]['epoch']     
    if save_options['save_final']:
        save_params(keys=save_keys,
                    targets=save_target + [optimizer_list[-1]],
                    save_path=os.path.join(save_dir,
                                           save_options['save_name'] + '-final.pth'))
    writer.close()
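
`save_params(keys=..., targets=..., save_path=...)` is only called above. A plausible sketch of it is shown below; this is an assumption, not the project's actual helper, but it matches how the 'model', 'amp', and 'optimizer' entries are restored earlier in the example.

import torch


def save_params(keys, targets, save_path):
    """Hypothetical sketch: serialize each target's state_dict under its key."""
    torch.save({key: target.state_dict() for key, target in zip(keys, targets)},
               save_path)
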
Example #10
def main():
    global args
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    torch.manual_seed(args.seed + args.rank)
    np.random.seed(args.seed + args.rank)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    if args.binarizable:
        Model_binary_patch(model)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()

    else:
        model.cuda()

    optimizer = create_optimizer(args, model)

    use_amp = False
    if has_apex and args.amp:
        print('Using amp with --opt-level {}.'.format(args.opt_level))
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level)
        use_amp = True
    else:
        print('Do NOT use amp.')
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(model, args.resume)
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    resume_state = None

    if args.freeze_binary:
        Model_freeze_binary(model)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model,
                        device_ids=[args.local_rank
                                    ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    print(num_epochs)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if args.reset_lr_scheduler is not None:
        lr_scheduler.base_values = len(
            lr_scheduler.base_values) * [args.reset_lr_scheduler]
        lr_scheduler.step(start_epoch)

    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    # Using pruner to get sparse weights
    if args.prune:
        pruner = Pruner_mixed(model, 0, 100, args.pruner)
    else:
        pruner = None

    dataset_train = torchvision.datasets.CIFAR100(root='~/Downloads/CIFAR100',
                                                  train=True,
                                                  download=True)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)

    loader_train = create_loader_CIFAR100(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        rand_erase_count=args.recount,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation='random',
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        is_clean_data=args.clean_train,
    )

    dataset_eval = torchvision.datasets.CIFAR100(root='~/Downloads/CIFAR100',
                                                 train=False,
                                                 download=True)

    loader_eval = create_loader_CIFAR100(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy(
            multiplier=args.softmax_multiplier).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    saver_last_10_epochs = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        os.makedirs(output_dir + '/Top')
        os.makedirs(output_dir + '/Last')
        saver = CheckpointSaver(
            checkpoint_dir=output_dir + '/Top',
            decreasing=decreasing,
            max_history=10)  # Save the results of the top 10 epochs
        saver_last_10_epochs = CheckpointSaver(
            checkpoint_dir=output_dir + '/Last',
            decreasing=decreasing,
            max_history=10)  # Save the results of the last 10 epochs
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)
            f.write('==============================\n')
            f.write(model.__str__())
            # if pruner:
            #     f.write('\n Sparsity \n')
            #     #f.write(pruner.threshold_dict.__str__())
            #     f.write('\n pruner.start_epoch={}, pruner.end_epoch={}'.format(pruner.start_epoch, pruner.end_epoch))

    tensorboard_writer = SummaryWriter(output_dir)

    try:
        for epoch in range(start_epoch, num_epochs):

            global alpha
            alpha = get_alpha(epoch, args)

            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            if pruner:
                pruner.on_epoch_begin(epoch)  # pruning

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        tensorboard_writer=tensorboard_writer,
                                        pruner=pruner)

            if pruner:
                pruner.print_statistics()

            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    tensorboard_writer=tensorboard_writer,
                                    epoch=epoch)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    metric=save_metric,
                    use_amp=use_amp)
            if saver_last_10_epochs is not None:
                # save this epoch's checkpoint as well, keeping only the 10 most recent
                _, _ = saver_last_10_epochs.save_checkpoint(model,
                                                            optimizer,
                                                            args,
                                                            epoch=epoch,
                                                            metric=epoch,
                                                            use_amp=use_amp)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
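
As a reference for the mixup branch near the top of this example, here is a minimal, framework-free sketch of what a soft-target cross-entropy computes (an illustration only, not timm's SoftTargetCrossEntropy, which in this code also takes a multiplier argument):

import torch
import torch.nn.functional as F

def soft_target_cross_entropy(logits, soft_targets):
    # Mean over the batch of -sum_c target[c] * log_softmax(logits)[c];
    # with one-hot targets this reduces to ordinary cross-entropy.
    log_probs = F.log_softmax(logits, dim=-1)
    return (-soft_targets * log_probs).sum(dim=-1).mean()
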
예제 #11
0
def main(args):
    # get device
    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # build model
    model = STTR(args).to(device)
    print_param(model)

    # set learning rate
    param_dicts = [
        {"params": [p for n, p in model.named_parameters() if
                    "backbone" not in n and "regression" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone,
        },
        {
            "params": [p for n, p in model.named_parameters() if "regression" in n and p.requires_grad],
            "lr": args.lr_regression,
        },
    ]

    # define optimizer and learning rate scheduler
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_decay_rate)

    # mixed precision training
    if args.apex:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    else:
        amp = None

    # load checkpoint if provided
    prev_best = np.inf
    if args.resume != '':
        if not os.path.isfile(args.resume):
            raise RuntimeError(f"=> no checkpoint found at '{args.resume}'")
        checkpoint = torch.load(args.resume)

        pretrained_dict = checkpoint['state_dict']
        missing, unexpected = model.load_state_dict(pretrained_dict, strict=False)
        # check missing and unexpected keys
        if len(missing) > 0:
            print("Missing keys: ", ','.join(missing))
            raise Exception("Missing keys.")
        unexpected = [k for k in unexpected if 'running_mean' not in k and 'running_var' not in k]  # skip bn params
        if len(unexpected) > 0:
            print("Unexpected keys: ", ','.join(unexpected))
            raise Exception("Unexpected keys.")
        print("Pre-trained model successfully loaded.")

        # if not ft/inference/eval, load states for optimizer, lr_scheduler, amp and prev best
        if not (args.ft or args.inference or args.eval):
            args.start_epoch = checkpoint['epoch'] + 1
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            prev_best = checkpoint['best_pred']
            if args.apex:
                amp.load_state_dict(checkpoint['amp'])
            print("Pre-trained optimizer, lr scheduler and stats successfully loaded.")

    # inference
    if args.inference:
        print("Start inference")
        _, _, data_loader = build_data_loader(args)
        inference(model, data_loader, device, args.downsample)

        return

    # initiate saver and logger
    checkpoint_saver = Saver(args)
    summary_writer = TensorboardSummary(checkpoint_saver.experiment_dir)

    # build dataloader
    data_loader_train, data_loader_val, _ = build_data_loader(args)

    # build loss criterion
    criterion = build_criterion(args)

    # set downsample rate
    set_downsample(args)

    # eval
    if args.eval:
        print("Start evaluation")
        evaluate(model, criterion, data_loader_val, device, 0, summary_writer, True)
        return

    # train
    print("Start training")
    for epoch in range(args.start_epoch, args.epochs):
        # train
        print("Epoch: %d" % epoch)
        train_one_epoch(model, data_loader_train, optimizer, criterion, device, epoch, summary_writer,
                        args.clip_max_norm, amp)

        # step lr if not pretraining
        if not args.pre_train:
            lr_scheduler.step()
            print("current learning rate", lr_scheduler.get_lr())

        # empty cache
        torch.cuda.empty_cache()

        # save if pretrain, save every 50 epochs
        if args.pre_train or epoch % 50 == 0:
            save_checkpoint(epoch, model, optimizer, lr_scheduler, prev_best, checkpoint_saver, False, amp)

        # validate
        eval_stats = evaluate(model, criterion, data_loader_val, device, epoch, summary_writer, False)
        # save if best
        if prev_best > eval_stats['epe'] and 0.5 > eval_stats['px_error_rate']:
            save_checkpoint(epoch, model, optimizer, lr_scheduler, prev_best, checkpoint_saver, True, amp)

    # save final model
    save_checkpoint(epoch, model, optimizer, lr_scheduler, prev_best, checkpoint_saver, False, amp)

    return
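
For reference, a tiny self-contained illustration of the strict=False behaviour the checkpoint-loading code above relies on: load_state_dict returns the missing and unexpected keys instead of raising, so the caller decides which mismatches are fatal.

import torch.nn as nn

src = nn.Linear(4, 4)
dst = nn.Sequential(nn.Linear(4, 4))            # same weights, but keys prefixed with "0."
missing, unexpected = dst.load_state_dict(src.state_dict(), strict=False)
print(missing)     # ['0.weight', '0.bias']
print(unexpected)  # ['weight', 'bias']
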
예제 #12
0
def main():

    # parse the args
    args = parse_option()

    # set the loader
    train_loader, n_data = get_train_loader(
        args)  # change to this if testing on cifar as baseline

    # set the model
    model, contrast, criterion_ab, criterion_l = set_model(args, n_data)

    # set the optimizer
    optimizer = set_optimizer(args, model)

    # set mixed precision
    if args.amp:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level)

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            contrast.load_state_dict(checkpoint['contrast'])
            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        l_loss, l_prob, ab_loss, ab_prob = train(epoch, train_loader, model,
                                                 contrast, criterion_l,
                                                 criterion_ab, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('l_loss', l_loss, epoch)
        logger.log_value('l_prob', l_prob, epoch)
        logger.log_value('ab_loss', ab_loss, epoch)
        logger.log_value('ab_prob', ab_prob, epoch)

        # save model
        if epoch % args.save_freq == 0:
            print('==> Saving...')
            state = {
                'opt': args,
                'model': model.state_dict(),
                'contrast': contrast.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
            }
            if args.amp:
                state['amp'] = amp.state_dict()
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
            # help release GPU memory
            del state

        torch.cuda.empty_cache()
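
The resume block above follows a pattern shared by several examples on this page: amp state is saved alongside the model and optimizer, and on resume it is restored only after amp.initialize has run. A condensed sketch of that pattern (assuming apex is installed, as the example does; the helper names are illustrative):

import torch
from apex import amp   # assumed available, as in the example above

def save_amp_checkpoint(path, model, optimizer, epoch, use_amp):
    state = {'model': model.state_dict(),
             'optimizer': optimizer.state_dict(),
             'epoch': epoch}
    if use_amp:
        state['amp'] = amp.state_dict()          # loss-scaler bookkeeping
    torch.save(state, path)

def load_amp_checkpoint(path, model, optimizer, use_amp):
    # Call only after amp.initialize(model, optimizer, ...) when use_amp is True.
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    if use_amp and 'amp' in ckpt:
        amp.load_state_dict(ckpt['amp'])
    return ckpt['epoch'] + 1                      # next epoch to run
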
예제 #13
0
def main():
    # parse command line
    parser = opts_parser()
    options = parser.parse_args()
    modelfile = options.modelfile
    cfg = config.from_parsed_arguments(options)
    if not options.cuda_device:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:%d' % options.cuda_device[0])
        torch.cuda.set_device(options.cuda_device[0])
        if options.cuda_sync_mode != 'auto':
            set_cuda_sync_mode(options.cuda_sync_mode)

    # prepare training data generator
    print("Preparing training data feed...")
    train_data = get_dataset(cfg, 'train')
    print_data_info(train_data)
    train_loader = get_dataloader(cfg, train_data, 'train')

    # start training data generation in background
    train_batches = iterate_infinitely(train_loader)
    train_batches = iterate_data(train_batches, device, cfg)

    # if told so, benchmark the creation of a given number of minibatches
    if cfg.get('benchmark_datafeed'):
        print("Benchmark: %d minibatches of %d items..." %
              (cfg['benchmark_datafeed'], cfg['batchsize']))
        import itertools
        t0 = time.time()
        next(
            itertools.islice(train_batches, cfg['benchmark_datafeed'],
                             cfg['benchmark_datafeed']), None)
        t1 = time.time()
        print("%.3gs per minibatch." % ((t1 - t0) / cfg['benchmark_datafeed']))
        return

    # if told so, play back a given key of the training data as audio
    if cfg.get('play_datafeed'):
        import simpleaudio as sa
        for batch in train_batches:
            for wav in batch[cfg['play_datafeed']]:
                if wav.dtype.is_floating_point:
                    wav = (wav * np.iinfo(np.int16).max).short()
                sa.WaveObject(
                    wav.cpu().numpy().T.data,
                    num_channels=wav.shape[0],
                    bytes_per_sample=2,
                    sample_rate=cfg['data.sample_rate']).play().wait_done()

    # prepare validation data generator
    print("Preparing validation data feed...")
    val_data = get_dataset(cfg, 'valid')
    print_data_info(val_data)
    val_loader = get_dataloader(cfg, val_data, 'valid')

    # enable cuDNN auto-tuning if on GPU and all data sizes are constant
    if options.cuda_device and not any(s is None
                                       for data in (train_data, val_data)
                                       for shape in data.shapes.values()
                                       for s in shape):
        torch.backends.cudnn.benchmark = True

    # prepare model
    print("Preparing network...")
    # instantiate neural network
    model = get_model(cfg, train_data.shapes, train_data.dtypes,
                      train_data.num_classes, options.cuda_device)
    print(model)
    print_model_info(model)

    if cfg['train.teacher_model']:
        print("Preparing teacher network...")
        teacher_modelfile = cfg['train.teacher_model']
        teacher_device = torch.device(cfg['train.teacher_model.device']
                                      or device)
        teacher_cfg = dict(cfg)
        teacher_cfg.update(
            config.parse_config_file(
                teacher_modelfile.rsplit('.', 1)[0] + '.vars'))
        teacher_model = get_model(teacher_cfg, train_data.shapes,
                                  train_data.dtypes, train_data.num_classes,
                                  teacher_device.index)
        teacher_model.load_state_dict(
            torch.load(teacher_modelfile, map_location=teacher_device))
        teacher_model.train(False)

    # obtain cost functions
    train_metrics = get_metrics(cfg, 'train')
    val_metrics = get_metrics(cfg, 'valid')
    extract_loss = get_loss_from_metrics(cfg)

    # initialize optimizer
    params = model.parameters()
    if cfg['train.first_params']:
        first_params_count = cfg['train.first_params']
        # if a string, treat as a submodule name, figure out its param count
        if isinstance(first_params_count, str):
            first_params_count = sum(
                len(list(reduce(getattr, name.split('.'), model).parameters()))
                for name in first_params_count.split('+'))
        # advance the `params` iterator, keep the first parameters separately
        params = iter(params)
        first_params = [next(params) for _ in range(first_params_count)]
    optimizer = get_optimizer(cfg, params)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=cfg['train.eta_decay'],
        patience=cfg['train.patience'],
        cooldown=cfg['train.cooldown'],
        verbose=True)

    # initialize mixed-precision training
    if cfg['float16']:
        from apex import amp
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=cfg['float16.opt_level'])
        if cfg['train.teacher_model']:
            teacher_model = amp.initialize(teacher_model,
                                           opt_level=cfg['float16.opt_level'])

    # initialize tensorboard logger, if requested
    if options.logdir:
        from tensorboardize import TensorboardLogger
        logger = TensorboardLogger(options.logdir,
                                   cfg=cfg,
                                   dataloader=val_loader,
                                   model=model,
                                   optimizer=optimizer)
    else:
        logger = None

    # resume training state if possible
    if options.resume and os.path.exists(options.modelfile + '.resume'):
        state = torch.load(options.modelfile + '.resume', map_location=device)
        model.load_state_dict(state.pop('model'))
        optimizer.load_state_dict(state.pop('optimizer'))
        scheduler.load_state_dict(state.pop('scheduler'))
        history = state.pop('history')
        epoch = state.pop('epoch')
        if cfg['float16']:
            amp.load_state_dict(state.pop('amp'))
        if (cfg['train.first_params']
                and epoch > cfg['train.first_params.delay']):
            add_optimizer_params(optimizer, scheduler, first_params,
                                 cfg['train.first_params.eta_scale'])
    else:
        history = {}
        epoch = 0
        # load pretrained weights if requested
        if cfg['model.init_from']:
            model.load_state_dict(
                torch.load(os.path.join(os.path.dirname(__file__),
                                        cfg['model.init_from']),
                           map_location=device))
        else:
            # run custom initializations
            init_model(model, cfg)
        # log initial state
        if logger is not None:
            logger.log_start()

    # warn about unused configuration keys
    config.warn_unused_variables(
        cfg, ('train.epochs', 'train.epochsize', 'train.min_eta',
              'train.patience_reference', 'loss'))

    # run training loop
    print("Training:")
    for epoch in range(epoch, cfg['train.epochs']):
        # add first_params to optimizer when the delay has passed
        if (cfg['train.first_params']
                and cfg['train.first_params.delay'] == epoch):
            add_optimizer_params(optimizer, scheduler, first_params,
                                 cfg['train.first_params.eta_scale'])
            if cfg['debug']:
                print(
                    'Training first %d parameters with learning rate '
                    'scaled by %f.' %
                    (first_params_count, cfg['train.first_params.eta_scale']))
        # training pass
        model.train(True)
        if cfg['debug']:
            torch.autograd.set_detect_anomaly(True)
        train_errors = AverageMetrics()
        nans_in_a_row = 0
        for _ in tqdm.trange(cfg['train.epochsize'],
                             desc='Epoch %d/%d' %
                             (epoch + 1, cfg['train.epochs']),
                             ascii=bool(cfg['tqdm.ascii'])):
            # grab the next minibatch
            batch = next(train_batches)
            # reset gradients
            optimizer.zero_grad()
            # compute output
            preds = model(batch)
            # compute born-again output, if needed
            if cfg['train.teacher_model']:
                teacher_batch = copy_to_device(batch, teacher_device)
                with torch.no_grad():
                    teacher_preds = teacher_model(teacher_batch)
                teacher_preds = copy_to_device(teacher_preds, device)
                batch.update(
                    ('teacher.' + k, v) for k, v in teacher_preds.items())
            # compute training metrics and loss
            metrics = OrderedDict(
                (k, fn(preds, batch)) for k, fn in train_metrics.items())
            loss = extract_loss(metrics)
            # bail out if Not a Number
            if not np.isfinite(loss.item()):
                if cfg['debug']:
                    raise RuntimeError('Training error is NaN!')
                nans_in_a_row += 1
                if nans_in_a_row < 5:
                    print('Training error is NaN! Skipping step.')
                    continue
                else:
                    print('Training error is NaN! Stopping training.')
                    return 1
            else:
                nans_in_a_row = 0
            train_errors += metrics
            train_errors += {'loss': loss.item()}
            # backprop and update
            if cfg['float16']:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()
        print_metrics('Train', train_errors.aggregate())
        del batch, preds, loss

        # validation pass
        model.train(False)
        val_errors = AverageMetrics()
        for batch in iterate_data(iter(val_loader), device, cfg):
            with torch.no_grad():
                preds = model(batch)
                metrics = {
                    k: fn(preds, batch)
                    for k, fn in val_metrics.items()
                }
            val_loss = float(extract_loss(metrics).item())
            val_errors += metrics
            val_errors += {'loss': val_loss}
        print_metrics('Valid', val_errors.aggregate())
        del batch, preds, val_loss

        log_metrics(train_errors.aggregate(), val_errors.aggregate(), history,
                    modelfile)
        if logger is not None:
            logger.log_epoch(epoch, {k: v[-1] for k, v in history.items()})

        # learning rate update
        reference = history[cfg['train.patience_reference'].lstrip('-')][-1]
        if hasattr(reference, 'mean'):
            reference = reference.mean()
        if cfg['train.patience_reference'].startswith('-'):
            reference *= -1
        scheduler.step(reference)
        if optimizer.param_groups[0]['lr'] < cfg['train.min_eta']:
            print('Learning rate fell below threshold. Stopping training.')
            break

        # save training state to resume file
        resume_state = dict(model=model.state_dict(),
                            optimizer=optimizer.state_dict(),
                            scheduler=scheduler.state_dict(),
                            epoch=epoch + 1,
                            history=history)
        if cfg['float16']:
            resume_state['amp'] = amp.state_dict()
        torch.save(resume_state, options.modelfile + '.resume')
        del resume_state

    # save final network and the configuration used
    print("Saving final model")
    save_model(modelfile, model, cfg)

    # delete resume file if any
    if os.path.exists(options.modelfile + '.resume'):
        os.remove(options.modelfile + '.resume')

    # log the final state
    if logger is not None:
        logger.log_end(history)
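
The benchmark_datafeed branch above uses a slightly cryptic itertools idiom: next(islice(it, n, n), None) advances an iterator by exactly n items without materialising them. A small standalone sketch of the same timing logic (names are illustrative):

import itertools
import time

def seconds_per_item(iterator, n):
    # islice(it, n, n) consumes the first n items and then stops immediately,
    # so this measures how long producing n items takes without storing them.
    t0 = time.time()
    next(itertools.islice(iterator, n, n), None)
    return (time.time() - t0) / n

print(seconds_per_item(iter(range(10**6)), 1000))
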
예제 #14
0
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data_train, meta_data_eval, symbols, phonemes
    # Audio processor
    ap = AudioProcessor(**c.audio)
    if 'characters' in c.keys():
        symbols, phonemes = make_symbols(**c.characters)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])
    num_chars = len(phonemes) if c.use_phonemes else len(symbols)

    # load data instances
    meta_data_train, meta_data_eval = load_meta_data(c.datasets)

    # set the portion of the data used for training
    if 'train_portion' in c.keys():
        meta_data_train = meta_data_train[:int(
            len(meta_data_train) * c.train_portion)]
    if 'eval_portion' in c.keys():
        meta_data_eval = meta_data_eval[:int(
            len(meta_data_eval) * c.eval_portion)]

    # parse speakers
    if c.use_speaker_embedding:
        speakers = get_speakers(meta_data_train)
        if args.restore_path:
            if c.use_external_speaker_embedding_file:  # restoring a checkpoint that uses an external embedding file
                prev_out_path = os.path.dirname(args.restore_path)
                speaker_mapping = load_speaker_mapping(prev_out_path)
                if not speaker_mapping:
                    print(
                        "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
                    )
                    speaker_mapping = load_speaker_mapping(
                        c.external_speaker_embedding_file)
                    if not speaker_mapping:
                        raise RuntimeError(
                            "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file"
                        )
                speaker_embedding_dim = len(speaker_mapping[list(
                    speaker_mapping.keys())[0]]['embedding'])
            elif not c.use_external_speaker_embedding_file:  # restoring a checkpoint that does not use an external embedding file
                prev_out_path = os.path.dirname(args.restore_path)
                speaker_mapping = load_speaker_mapping(prev_out_path)
                speaker_embedding_dim = None
                assert all([speaker in speaker_mapping
                            for speaker in speakers]), "As of now, you cannot " \
                                                    "introduce new speakers to " \
                                                    "a previously trained model."
        elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file:  # starting a new training run with an external embedding file
            speaker_mapping = load_speaker_mapping(
                c.external_speaker_embedding_file)
            speaker_embedding_dim = len(speaker_mapping[list(
                speaker_mapping.keys())[0]]['embedding'])
        elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file:  # external embedding requested but no file given
            raise RuntimeError(
                "use_external_speaker_embedding_file is True, so you need to pass an external speaker embedding file; run the GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in the notebooks/ folder"
            )
        else:  # starting a new training run without an external embedding file
            speaker_mapping = {name: i for i, name in enumerate(speakers)}
            speaker_embedding_dim = None
        save_speaker_mapping(OUT_PATH, speaker_mapping)
        num_speakers = len(speaker_mapping)
        print("Training with {} speakers: {}".format(num_speakers,
                                                     ", ".join(speakers)))
    else:
        num_speakers = 0
        speaker_embedding_dim = None
        speaker_mapping = None

    model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim)

    params = set_weight_decay(model, c.wd)
    optimizer = RAdam(params, lr=c.lr, weight_decay=0)
    if c.stopnet and c.separate_stopnet:
        optimizer_st = RAdam(model.decoder.stopnet.parameters(),
                             lr=c.lr,
                             weight_decay=0)
    else:
        optimizer_st = None

    if c.apex_amp_level == "O1":
        # pylint: disable=import-outside-toplevel
        from apex import amp
        model.cuda()
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=c.apex_amp_level)
    else:
        amp = None

    # setup criterion
    criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            # optimizer.load_state_dict(checkpoint['optimizer'])
            if c.reinit_layers:
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
        except KeyError:
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            # torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt'))
            # print("State Dict saved for debug in: ", os.path.join(OUT_PATH, 'state_dict.pt'))
            model.load_state_dict(model_dict)
            del model_dict

        if amp and 'amp' in checkpoint:
            amp.load_state_dict(checkpoint['amp'])

        for group in optimizer.param_groups:
            group['lr'] = c.lr
        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model.cuda()
        criterion.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)

    if c.noam_schedule:
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.warmup_steps,
                           last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    if 'best_loss' not in locals():
        best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        # set gradual training
        if c.gradual_training is not None:
            r, c.batch_size = gradual_training_scheduler(global_step, c)
            c.r = r
            model.decoder.set_r(r)
            if c.bidirectional_decoder:
                model.decoder_backward.set_r(r)
            print("\n > Number of output frames:", model.decoder.r)
        train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                 optimizer_st, scheduler, ap,
                                                 global_step, epoch, amp,
                                                 speaker_mapping)
        if epoch % 10 == 0:
            eval_avg_loss_dict = evaluate(model, criterion, ap, global_step,
                                          epoch, speaker_mapping)
            c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
            target_loss = train_avg_loss_dict['avg_postnet_loss']
            if c.run_eval:
                target_loss = eval_avg_loss_dict['avg_postnet_loss']
            best_loss = save_best_model(
                target_loss,
                best_loss,
                model,
                optimizer,
                global_step,
                epoch,
                c.r,
                OUT_PATH,
                amp_state_dict=amp.state_dict() if amp else None)
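
For context, the NoamLR scheduler used above normally follows the warm-up rule from "Attention Is All You Need": a linear warm-up for warmup_steps, then inverse-square-root decay. A sketch of that rule (an approximation of what NoamLR presumably computes, not its exact implementation):

def noam_lr(base_lr, step, warmup_steps):
    # Peaks at roughly base_lr when step == warmup_steps, decays as 1/sqrt(step) afterwards.
    step = max(step, 1)
    return base_lr * warmup_steps ** 0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

for s in (1, 100, 4000, 40000):
    print(s, noam_lr(1e-3, s, warmup_steps=4000))
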
예제 #15
0
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    args.rank = args.rank * ngpus_per_node + gpu
    # model name for checkpoint
    args.model = "{}-{}".format(
        args.arch,
        os.path.splitext(os.path.basename(args.config_file))[0])
    if args.gpu == 0:
        print('model:', args.model)
    print('rank: {} / {}'.format(args.rank, args.world_size))
    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)
    torch.cuda.set_device(args.gpu)
    # init the args
    global best_pred, acclist_train, acclist_val

    if args.gpu == 0:
        print(args)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True
    # init dataloader
    transform_train, transform_val = encoding.transforms.get_transform(
        args.dataset, args.base_size, args.crop_size)
    if args.auto_policy is not None:
        print(f'Using auto_policy: {args.auto_policy}')
        from augment import Augmentation
        auto_policy = Augmentation(at.load(args.auto_policy))
        transform_train.transforms.insert(0, auto_policy)

    trainset = encoding.datasets.get_dataset(args.dataset,
                                             root=args.data_dir,
                                             transform=transform_train,
                                             train=True,
                                             download=True)
    valset = encoding.datasets.get_dataset(args.dataset,
                                           root=args.data_dir,
                                           transform=transform_val,
                                           train=False,
                                           download=True)

    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_sampler = torch.utils.data.distributed.DistributedSampler(
        valset, shuffle=False)
    val_loader = torch.utils.data.DataLoader(valset,
                                             batch_size=args.test_batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler)

    # init the model
    arch = importlib.import_module('generator.' + args.arch)
    model = arch.config_network(args.config_file)
    if args.gpu == 0:
        print(model)

    if args.mixup > 0:
        train_loader = MixUpWrapper(args.mixup, 1000, train_loader, args.gpu)
        criterion = NLLMultiLabelSmooth(args.label_smoothing)
    elif args.label_smoothing > 0.0:
        criterion = LabelSmoothing(args.label_smoothing)
    else:
        criterion = nn.CrossEntropyLoss()

    model.cuda(args.gpu)
    criterion.cuda(args.gpu)
    # criterion and optimizer
    if args.no_bn_wd:
        parameters = model.named_parameters()
        param_dict = {}
        for k, v in parameters:
            param_dict[k] = v
        bn_params = [
            v for n, v in param_dict.items() if ('bn' in n or 'bias' in n)
        ]
        rest_params = [
            v for n, v in param_dict.items() if not ('bn' in n or 'bias' in n)
        ]
        if args.gpu == 0:
            print(" Weight decay NOT applied to BN parameters ")
            print(
                f'len(parameters): {len(list(model.parameters()))} = {len(bn_params)} + {len(rest_params)}'
            )
        optimizer = torch.optim.SGD([{
            'params': bn_params,
            'weight_decay': 0
        }, {
            'params': rest_params,
            'weight_decay': args.weight_decay
        }],
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    if args.amp:
        #optimizer = amp_handle.wrap_optimizer(optimizer)
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
        #from apex import amp
        DDP = apex.parallel.DistributedDataParallel
        model = DDP(model, delay_allreduce=True)
    else:
        DDP = DistributedDataParallel
        model = DDP(model, device_ids=[args.gpu])

    # check point
    if args.resume is not None:
        if os.path.isfile(args.resume):
            if args.gpu == 0:
                print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint[
                'epoch'] + 1 if args.start_epoch == 0 else args.start_epoch
            best_pred = checkpoint['best_pred']
            acclist_train = checkpoint['acclist_train']
            acclist_val = checkpoint['acclist_val']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if args.amp:
                amp.load_state_dict(checkpoint['amp'])
            if args.gpu == 0:
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            raise RuntimeError("=> no resume checkpoint found at '{}'".format(
                args.resume))
    scheduler = LR_Scheduler(args.lr_scheduler,
                             base_lr=args.lr,
                             num_epochs=args.epochs,
                             iters_per_epoch=len(train_loader),
                             warmup_epochs=args.warmup_epochs)

    def train(epoch):
        train_sampler.set_epoch(epoch)
        model.train()
        losses = AverageMeter()
        top1 = AverageMeter()
        global best_pred, acclist_train
        tic = time.time()
        for batch_idx, (data, target) in enumerate(train_loader):
            scheduler(optimizer, batch_idx, epoch, best_pred)
            if not args.mixup:
                data, target = data.cuda(args.gpu), target.cuda(args.gpu)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            if not args.mixup:
                acc1 = accuracy(output, target, topk=(1, ))
                top1.update(acc1[0], data.size(0))

            losses.update(loss.item(), data.size(0))
            if batch_idx % 100 == 0 and args.gpu == 0:
                iter_per_sec = 100.0 / (time.time() -
                                        tic) if batch_idx != 0 else 1.0 / (
                                            time.time() - tic)
                tic = time.time()
                if args.mixup:
                    #print('Batch: %d| Loss: %.3f'%(batch_idx, losses.avg))
                    print('Epoch: {}, Iter: {}, Speed: {:.3f} iter/sec, Train loss: {:.3f}'. \
                          format(epoch, batch_idx, iter_per_sec, losses.avg.item()))
                else:
                    #print('Batch: %d| Loss: %.3f | Top1: %.3f'%(batch_idx, losses.avg, top1.avg))
                    print('Epoch: {}, Iter: {}, Speed: {:.3f} iter/sec, Top1: {:.3f}'. \
                          format(epoch, batch_idx, iter_per_sec, top1.avg.item()))

        acclist_train += [top1.avg]

    def validate(epoch):
        model.eval()
        top1 = AverageMeter()
        top5 = AverageMeter()
        global best_pred, acclist_train, acclist_val
        is_best = False
        for batch_idx, (data, target) in enumerate(val_loader):
            data, target = data.cuda(args.gpu), target.cuda(args.gpu)
            with torch.no_grad():
                output = model(data)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                top1.update(acc1[0], data.size(0))
                top5.update(acc5[0], data.size(0))

        # sum all
        sum1, cnt1, sum5, cnt5 = torch_dist_sum(args.gpu, top1.sum, top1.count,
                                                top5.sum, top5.count)

        if args.eval:
            if args.gpu == 0:
                top1_acc = sum(sum1) / sum(cnt1)
                top5_acc = sum(sum5) / sum(cnt5)
                print('Validation: Top1: %.3f | Top5: %.3f' %
                      (top1_acc, top5_acc))
            return

        if args.gpu == 0:
            top1_acc = sum(sum1) / sum(cnt1)
            top5_acc = sum(sum5) / sum(cnt5)
            print('Validation: Top1: %.3f | Top5: %.3f' % (top1_acc, top5_acc))

            # save checkpoint
            acclist_val += [top1_acc]
            if top1_acc > best_pred:
                best_pred = top1_acc
                is_best = True
            state_dict = {
                'epoch': epoch,
                'state_dict': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_pred': best_pred,
                'acclist_train': acclist_train,
                'acclist_val': acclist_val,
            }
            if args.amp:
                state_dict['amp'] = amp.state_dict()
            encoding.utils.save_checkpoint(state_dict,
                                           args=args,
                                           is_best=is_best)

    if args.export:
        if args.gpu == 0:
            torch.save(model.module.state_dict(), args.export + '.pth')
        return

    if args.eval:
        validate(args.start_epoch)
        return

    for epoch in range(args.start_epoch, args.epochs):
        tic = time.time()
        train(epoch)
        if epoch % 10 == 0:  # or epoch == args.epochs-1:
            validate(epoch)
        elapsed = time.time() - tic
        if args.gpu == 0:
            print(f'Epoch: {epoch}, Time cost: {elapsed}')

    if args.gpu == 0:
        encoding.utils.save_checkpoint(
            {
                'epoch': args.epochs - 1,
                'state_dict': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_pred': best_pred,
                'acclist_train': acclist_train,
                'acclist_val': acclist_val,
            },
            args=args,
            is_best=False)
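
Note that the LR_Scheduler above is invoked once per iteration (scheduler(optimizer, batch_idx, epoch, best_pred)), i.e. the learning rate is recomputed per step rather than per epoch. A minimal sketch of a per-iteration warm-up plus cosine schedule in that spirit (illustrative only; the actual LR_Scheduler in the encoding package may differ in details):

import math

def per_iter_lr(base_lr, epoch, batch_idx, iters_per_epoch, num_epochs, warmup_epochs):
    # Convert (epoch, batch) to a global iteration, warm up linearly, then cosine-decay to zero.
    cur_iter = epoch * iters_per_epoch + batch_idx
    warmup_iters = warmup_epochs * iters_per_epoch
    total_iters = num_epochs * iters_per_epoch
    if cur_iter < warmup_iters:
        return base_lr * (cur_iter + 1) / warmup_iters
    progress = (cur_iter - warmup_iters) / max(1, total_iters - warmup_iters)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
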
예제 #16
0
 def load_state_dict(self, state_dict):
     if 'load_state_dict' in amp.__dict__:
         amp.load_state_dict(state_dict)
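
The guard above exists because older apex builds do not expose amp.load_state_dict, so checking amp.__dict__ makes restoring a harmless no-op there. A hypothetical save-side counterpart with the same guard (not part of the source):

 def state_dict(self):
     # Only serialize amp state when the installed apex build provides it.
     return amp.state_dict() if 'state_dict' in amp.__dict__ else {}
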
예제 #17
0
def train(cfg):

    # Initialize a txt file for saving the results
    with open(
            "/home/jhjeong/jiho_deep/deepspeech.pytorch/jiho_result/result.txt",
            "w") as ff:
        ff.write("학습 시작! \n")

    # Set seeds for determinism
    torch.manual_seed(cfg.training.seed)
    torch.cuda.manual_seed_all(cfg.training.seed)
    np.random.seed(cfg.training.seed)
    random.seed(cfg.training.seed)

    main_proc = True
    device = torch.device("cpu" if cfg.training.no_cuda else "cuda")

    is_distributed = os.environ.get(
        "LOCAL_RANK")  # If local rank exists, distributed env

    if is_distributed:
        # when using NCCL, on failures, surviving nodes will deadlock on NCCL ops
        # because NCCL uses a spin-lock on the device. Set this env var to
        # enable a watchdog thread that will destroy stale NCCL communicators
        os.environ["NCCL_BLOCKING_WAIT"] = "1"

        device_id = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(device_id)
        print(f"Setting CUDA Device to {device_id}")

        dist.init_process_group(backend=cfg.training.dist_backend)
        main_proc = device_id == 0  # Main process handles saving of models and reporting

    checkpoint_handler = CheckpointHandler(
        save_folder=to_absolute_path(cfg.checkpointing.save_folder),
        best_val_model_name=cfg.checkpointing.best_val_model_name,
        checkpoint_per_iteration=cfg.checkpointing.checkpoint_per_iteration,
        save_n_recent_models=cfg.checkpointing.save_n_recent_models)

    # Choose whether to use visdom or tensorboard
    if main_proc and cfg.visualization.visdom:
        visdom_logger = VisdomLogger(id=cfg.visualization.id,
                                     num_epochs=cfg.training.epochs)
    if main_proc and cfg.visualization.tensorboard:
        tensorboard_logger = TensorBoardLogger(
            id=cfg.visualization.id,
            log_dir=to_absolute_path(cfg.visualization.log_dir),
            log_params=cfg.visualization.log_params)

    if cfg.checkpointing.load_auto_checkpoint:
        latest_checkpoint = checkpoint_handler.find_latest_checkpoint()
        if latest_checkpoint:
            cfg.checkpointing.continue_from = latest_checkpoint

    # From here on
    if cfg.checkpointing.continue_from:  # Starting from previous model
        state = TrainingState.load_state(
            state_path=to_absolute_path(cfg.checkpointing.continue_from))
        model = state.model
        if cfg.training.finetune:
            state.init_finetune_states(cfg.training.epochs)

        if main_proc and cfg.visualization.visdom:  # Add previous scores to visdom graph
            visdom_logger.load_previous_values(state.epoch, state.results)
        if main_proc and cfg.visualization.tensorboard:  # Previous scores to tensorboard logs
            tensorboard_logger.load_previous_values(state.epoch, state.results)
    else:
        # Initialise new model training
        with open(to_absolute_path(cfg.data.labels_path)) as label_file:
            labels = json.load(label_file)  # label(a,b,c ...)

        audio_conf = dict(sample_rate=cfg.data.sample_rate,
                          window_size=cfg.data.window_size,
                          window_stride=cfg.data.window_stride,
                          window=cfg.data.window)
        if cfg.augmentation.noise_dir:
            audio_conf.update(noise_dir=to_absolute_path(
                cfg.augmentation.noise_dir),
                              noise_prob=cfg.augmentation.noise_prob,
                              noise_levels=(cfg.augmentation.noise_min,
                                            cfg.augmentation.noise_max))

        rnn_type = cfg.model.rnn_type.lower()
        assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"

        # Create the DeepSpeech model
        model = DeepSpeech(rnn_hidden_size=cfg.model.hidden_size,
                           nb_layers=cfg.model.hidden_layers,
                           labels=labels,
                           rnn_type=supported_rnns[rnn_type],
                           audio_conf=audio_conf,
                           bidirectional=cfg.model.bidirectional)

        state = TrainingState(model=model)
        state.init_results_tracking(epochs=cfg.training.epochs)

    # Data setup
    evaluation_decoder = GreedyDecoder(
        model.labels)  # Decoder used for validation

    # Set up the data paths
    train_dataset = SpectrogramDataset(
        audio_conf=model.audio_conf,
        manifest_filepath=to_absolute_path(cfg.data.train_manifest),
        labels=model.labels,
        normalize=True,
        speed_volume_perturb=cfg.augmentation.speed_volume_perturb,
        spec_augment=cfg.augmentation.spec_augment)

    test_dataset = SpectrogramDataset(audio_conf=model.audio_conf,
                                      manifest_filepath=to_absolute_path(
                                          cfg.data.val_manifest),
                                      labels=model.labels,
                                      normalize=True,
                                      speed_volume_perturb=False,
                                      spec_augment=False)

    if not is_distributed:
        train_sampler = DSRandomSampler(dataset=train_dataset,
                                        batch_size=cfg.data.batch_size,
                                        start_index=state.training_step)
    else:
        train_sampler = DSElasticDistributedSampler(
            dataset=train_dataset,
            batch_size=cfg.data.batch_size,
            start_index=state.training_step)

    # Data loading
    train_loader = AudioDataLoader(dataset=train_dataset,
                                   num_workers=cfg.data.num_workers,
                                   batch_sampler=train_sampler)
    test_loader = AudioDataLoader(dataset=test_dataset,
                                  num_workers=cfg.data.num_workers,
                                  batch_size=cfg.data.batch_size)

    model = model.to(device)
    parameters = model.parameters()

    if cfg.optimizer.adam:
        optimizer = torch.optim.AdamW(parameters,
                                      lr=cfg.optimizer.learning_rate,
                                      betas=cfg.optimizer.betas,
                                      eps=cfg.optimizer.eps,
                                      weight_decay=cfg.optimizer.weight_decay)
    else:
        optimizer = torch.optim.SGD(parameters,
                                    lr=cfg.optimizer.learning_rate,
                                    momentum=cfg.optimizer.momentum,
                                    nesterov=True,
                                    weight_decay=cfg.optimizer.weight_decay)

    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=cfg.apex.opt_level,
                                      loss_scale=cfg.apex.loss_scale)

    if state.optim_state is not None:
        optimizer.load_state_dict(state.optim_state)
        amp.load_state_dict(state.amp_state)

    # Track states for optimizer/amp
    state.track_optim_state(optimizer)
    state.track_amp_state(amp)

    if is_distributed:
        model = DistributedDataParallel(model, device_ids=[device_id])

    print(model)
    print("Number of parameters: %d" % DeepSpeech.get_param_size(model))

    criterion = CTCLoss()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    for epoch in range(state.epoch, cfg.training.epochs):
        model.train()
        end = time.time()
        start_epoch_time = time.time()
        state.set_epoch(epoch=epoch)
        train_sampler.set_epoch(epoch=epoch)
        train_sampler.reset_training_step(training_step=state.training_step)

        # Iterate over the available training data
        for i, (data) in enumerate(train_loader, start=state.training_step):
            state.set_training_step(training_step=i)
            inputs, targets, input_percentages, target_sizes = data
            input_sizes = input_percentages.mul_(int(inputs.size(3))).int()

            # measure data loading time
            data_time.update(time.time() - end)
            inputs = inputs.to(device)

            out, output_sizes = model(inputs, input_sizes)
            out = out.transpose(0, 1)  # TxNxH

            float_out = out.float()  # ensure float32 for loss
            loss = criterion(float_out, targets, output_sizes,
                             target_sizes).to(device)
            loss = loss / inputs.size(0)  # average the loss by minibatch
            loss_value = loss.item()

            # Check to ensure valid loss was calculated
            valid_loss, error = check_loss(loss, loss_value)
            if valid_loss:
                optimizer.zero_grad()

                # compute gradient
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               cfg.optimizer.max_norm)
                optimizer.step()
            else:
                print(error)
                print('Skipping grad update')
                loss_value = 0

            state.avg_loss += loss_value
            losses.update(loss_value, inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      (epoch + 1), (i + 1),
                      len(train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses))

            if main_proc and cfg.checkpointing.checkpoint_per_iteration:
                checkpoint_handler.save_iter_checkpoint_model(epoch=epoch,
                                                              i=i,
                                                              state=state)
            del loss, out, float_out

        state.avg_loss /= len(train_dataset)

        epoch_time = time.time() - start_epoch_time
        print('Training Summary Epoch: [{0}]\t'
              'Time taken (s): {epoch_time:.0f}\t'
              'Average Loss {loss:.3f}\t'.format(epoch + 1,
                                                 epoch_time=epoch_time,
                                                 loss=state.avg_loss))

        with open(
                "/home/jhjeong/jiho_deep/deepspeech.pytorch/jiho_result/result.txt",
                "a") as ff:
            ff.write("\n")
            ff.write("train -> ")
            ff.write("epoch : ")
            ff.write(str(epoch + 1))
            ff.write(" loss : ")
            ff.write(str(state.avg_loss))
            ff.write("\n")

        with torch.no_grad():
            wer, cer, output_data = evaluate(test_loader=test_loader,
                                             device=device,
                                             model=model,
                                             decoder=evaluation_decoder,
                                             target_decoder=evaluation_decoder)

        state.add_results(epoch=epoch,
                          loss_result=state.avg_loss,
                          wer_result=wer,
                          cer_result=cer)

        print('Validation Summary Epoch: [{0}]\t'
              'Average CER {cer:.3f}\t'.format(epoch + 1, cer=cer))

        with open(
                "/home/jhjeong/jiho_deep/deepspeech.pytorch/jiho_result/result.txt",
                "a") as ff:
            ff.write("\n")
            ff.write("val -> ")
            ff.write("epoch : ")
            ff.write(str(epoch + 1))
            ff.write(" cer : ")
            ff.write(str(cer))
            ff.write("\n")

        # Update the tensorboard / visdom loggers
        if main_proc and cfg.visualization.visdom:
            visdom_logger.update(epoch, state.result_state)
        if main_proc and cfg.visualization.tensorboard:
            tensorboard_logger.update(epoch, state.result_state,
                                      model.named_parameters())

        if main_proc and cfg.checkpointing.checkpoint:  # Save epoch checkpoint
            checkpoint_handler.save_checkpoint_model(epoch=epoch, state=state)
        # anneal lr
        for g in optimizer.param_groups:
            g['lr'] = g['lr'] / cfg.optimizer.learning_anneal
        print('Learning rate annealed to: {lr:.6f}'.format(lr=g['lr']))

        if main_proc and (state.best_wer is None or state.best_wer > wer):
            checkpoint_handler.save_best_model(epoch=epoch, state=state)
            state.set_best_wer(wer)
            state.reset_avg_loss()
        state.reset_training_step()  # Reset training step for next epoch
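
The loss above follows the CTC convention of time-major activations (out.transpose(0, 1) gives T x N x H) with per-utterance input and target lengths. A self-contained shape check using PyTorch's built-in nn.CTCLoss (the example itself uses a separate CTCLoss binding, so treat this only as an illustration of the expected shapes):

import torch
import torch.nn as nn

T, N, C, S = 50, 4, 29, 10                      # time steps, batch, classes (blank=0), target length
log_probs = torch.randn(T, N, C).log_softmax(dim=-1)
targets = torch.randint(1, C, (N, S), dtype=torch.long)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), S, dtype=torch.long)
loss = nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)
print(loss.item())
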
예제 #18
0
    def init_from_checkpoint(self,
                             path: str,
                             reset_best_ckpt: bool = False,
                             reset_scheduler: bool = False,
                             reset_optimizer: bool = False,
                             reset_iter_state: bool = False) -> None:
        """
        Initialize the trainer from a given checkpoint file.

        This checkpoint file contains not only model parameters, but also
        scheduler and optimizer states, see `self._save_checkpoint`.

        :param path: path to checkpoint
        :param reset_best_ckpt: reset tracking of the best checkpoint,
                                use for domain adaptation with a new dev
                                set or when using a new metric for fine-tuning.
        :param reset_scheduler: reset the learning rate scheduler, and do not
                                use the one stored in the checkpoint.
        :param reset_optimizer: reset the optimizer, and do not use the one
                                stored in the checkpoint.
        :param reset_iter_state: reset the sampler's internal state and do not
                                use the one stored in the checkpoint.
        """
        logger.info("Loading model from %s", path)
        model_checkpoint = load_checkpoint(path=path, use_cuda=self.use_cuda)

        # restore model and optimizer parameters
        self.model.load_state_dict(model_checkpoint["model_state"])

        if not reset_optimizer:
            self.optimizer.load_state_dict(model_checkpoint["optimizer_state"])
        else:
            logger.info("Reset optimizer.")

        if not reset_scheduler:
            if model_checkpoint["scheduler_state"] is not None and \
                    self.scheduler is not None:
                self.scheduler.load_state_dict(
                    model_checkpoint["scheduler_state"])
        else:
            logger.info("Reset scheduler.")

        # restore counts
        self.stats.steps = model_checkpoint["steps"]
        self.stats.total_tokens = model_checkpoint["total_tokens"]

        if not reset_best_ckpt:
            self.stats.best_ckpt_score = model_checkpoint["best_ckpt_score"]
            self.stats.best_ckpt_iter = model_checkpoint["best_ckpt_iteration"]
        else:
            logger.info("Reset tracking of the best checkpoint.")

        if (not reset_iter_state and model_checkpoint.get(
                'train_iter_state', None) is not None):
            self.train_iter_state = model_checkpoint["train_iter_state"]

        # move parameters to cuda
        if self.use_cuda:
            self.model.to(self.device)

        # fp16
        if self.fp16 and model_checkpoint.get("amp_state", None) is not None:
            amp.load_state_dict(model_checkpoint['amp_state'])
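
A hypothetical usage of the method above for fine-tuning on a new dev set: keep the model weights but discard the optimizer, scheduler, and best-checkpoint tracking (the trainer instance and path are placeholders, not from the source):

trainer.init_from_checkpoint(
    path="models/base/best.ckpt",   # placeholder path
    reset_best_ckpt=True,           # new dev set, so the old best score is meaningless
    reset_scheduler=True,
    reset_optimizer=True,
    reset_iter_state=True)
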
예제 #19
0
File: train.py Project: CookiePPP/codedump
def load_checkpoint(checkpoint_path, model, optimizer):
    assert os.path.isfile(checkpoint_path)
    print("Loading checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    
    state_dict = {k.replace("encoder_speaker_embedding.weight", "encoder.encoder_speaker_embedding.weight"): v
                  for k, v in checkpoint_dict['state_dict'].items()}
    model.load_state_dict(state_dict)  # tmp for updating old models
    
    #model.load_state_dict(checkpoint_dict['state_dict']) # original
    
    if 'optimizer' in checkpoint_dict.keys(): optimizer.load_state_dict(checkpoint_dict['optimizer'])
    if 'amp' in checkpoint_dict.keys(): amp.load_state_dict(checkpoint_dict['amp'])
    if 'learning_rate' in checkpoint_dict.keys(): learning_rate = checkpoint_dict['learning_rate']
    #if 'hparams' in checkpoint_dict.keys(): hparams = checkpoint_dict['hparams']
    if 'best_validation_loss' in checkpoint_dict.keys(): best_validation_loss = checkpoint_dict['best_validation_loss']
    if 'average_loss' in checkpoint_dict.keys(): average_loss = checkpoint_dict['average_loss']
    if (start_from_checkpoints_from_zero):
        iteration = 0
    else:
        iteration = checkpoint_dict['iteration']
    print("Loaded checkpoint '{}' from iteration {}" .format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration, best_validation_loss


def save_checkpoint(model, optimizer, learning_rate, iteration, hparams, best_validation_loss, average_loss, speaker_id_lookup, filepath):
    from utils import load_filepaths_and_text
    tqdm.write("Saving model and optimizer state at iteration {} to {}".format(
        iteration, filepath))
    
    # get speaker names to ID
    speakerlist = load_filepaths_and_text(hparams.speakerlist)
    speaker_name_lookup = {x[2]: speaker_id_lookup[x[3]] for x in speakerlist}
    
    torch.save({'iteration': iteration,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate,
                #'amp': amp.state_dict(),
                'hparams': hparams,
                'speaker_id_lookup': speaker_id_lookup,
                'speaker_name_lookup': speaker_name_lookup,
                'best_validation_loss': best_validation_loss,
                'average_loss': average_loss}, filepath)
    tqdm.write("Saving Complete")


def validate(model, criterion, valset, iteration, batch_size, n_gpus,
             collate_fn, logger, distributed_run, rank, val_teacher_force_till, val_p_teacher_forcing, teacher_force=1):
    """Handles all the validation scoring and printing"""
    model.eval()
    with torch.no_grad():
        val_sampler = DistributedSampler(valset) if distributed_run else None
        val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1,
                                shuffle=False, batch_size=batch_size,
                                pin_memory=False, drop_last=True, collate_fn=collate_fn)
        if teacher_force == 1:
            val_teacher_force_till = 0
            val_p_teacher_forcing = 1.0
        elif teacher_force == 2:
            val_teacher_force_till = 0
            val_p_teacher_forcing = 0.0
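        # The branches above pin the teacher-forcing behaviour for this pass:
        # mode 1 is fully teacher forced (p=1.0), mode 2 is free-running
        # inference (p=0.0), and any other value keeps the caller's settings.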
        val_loss = 0.0
        diagonality = torch.zeros(1)
        avg_prob = torch.zeros(1)
        for i, batch in tqdm(enumerate(val_loader), desc="Validation", total=len(val_loader), smoothing=0): # i = index, batch = stuff in array[i]
            x, y = model.parse_batch(batch)
            y_pred = model(x, teacher_force_till=val_teacher_force_till, p_teacher_forcing=val_p_teacher_forcing)
            rate, prob = alignment_metric(x, y_pred)
            diagonality += rate
            avg_prob += prob
            loss, gate_loss = criterion(y_pred, y)
            if distributed_run:
                reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_val_loss = loss.item()
            val_loss += reduced_val_loss
            # end forloop
        val_loss = val_loss / (i + 1)
        diagonality = (diagonality / (i + 1)).item()
        avg_prob = (avg_prob / (i + 1)).item()
        # end torch.no_grad()
    model.train()
    if rank == 0:
        tqdm.write("Validation loss {}: {:9f}  Average Max Attention: {:9f}".format(iteration, val_loss, avg_prob))
        #logger.log_validation(val_loss, model, y, y_pred, iteration)
        if True:#iteration != 0:
            if teacher_force == 1:
                logger.log_teacher_forced_validation(val_loss, model, y, y_pred, iteration, val_teacher_force_till, val_p_teacher_forcing, diagonality, avg_prob)
            elif teacher_force == 2:
                logger.log_infer(val_loss, model, y, y_pred, iteration, val_teacher_force_till, val_p_teacher_forcing, diagonality, avg_prob)
            else:
                logger.log_validation(val_loss, model, y, y_pred, iteration, val_teacher_force_till, val_p_teacher_forcing, diagonality, avg_prob)
    return val_loss


def calculate_global_mean(data_loader, global_mean_npy, hparams):
    if global_mean_npy and os.path.exists(global_mean_npy):
        global_mean = np.load(global_mean_npy)
        return to_gpu(torch.tensor(global_mean).half()) if hparams.fp16_run else to_gpu(torch.tensor(global_mean).float())
    sums = []
    frames = []
    print('calculating global mean...')
    for i, batch in tqdm(enumerate(data_loader), total=len(data_loader), smoothing=0.001):
        text_padded, input_lengths, mel_padded, gate_padded,\
            output_lengths, speaker_ids, torchmoji_hidden, preserve_decoder_states = batch
        # padded values are 0.
        sums.append(mel_padded.double().sum(dim=(0, 2)))
        frames.append(output_lengths.double().sum())
    global_mean = sum(sums) / sum(frames)
    global_mean = to_gpu(global_mean.half()) if hparams.fp16_run else to_gpu(global_mean.float())
    if global_mean_npy:
        np.save(global_mean_npy, global_mean.cpu().numpy())
    return global_mean
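# A toy check of the reduction above (shapes are assumptions): with zero-padded
# mels of shape (batch, n_mel_channels, time), summing over batch and time and
# dividing by the total number of real frames gives the per-channel mean, e.g.
#   mel = torch.zeros(2, 3, 5); mel[0, :, :4] = 1.0; mel[1, :, :2] = 3.0
#   lengths = torch.tensor([4., 2.])
#   mel.double().sum(dim=(0, 2)) / lengths.double().sum()
#   -> tensor([1.6667, 1.6667, 1.6667], dtype=torch.float64)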


def train(output_directory, log_directory, checkpoint_path, warm_start, warm_start_force, n_gpus,
          rank, group_name, hparams):
    """Training and validation logging results to tensorboard and stdout

    Params
    ------
    output_directory (string): directory to save checkpoints
    log_directory (string) directory to save tensorboard logs
    checkpoint_path(string): checkpoint path
    n_gpus (int): number of gpus
    rank (int): rank of current gpu
    hparams (object): comma separated list of "name=value" pairs.
    """
    hparams.n_gpus = n_gpus
    hparams.rank = rank
    if hparams.distributed_run:
        init_distributed(hparams, n_gpus, rank, group_name)

    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    train_loader, valset, collate_fn, train_sampler, trainset = prepare_dataloaders(hparams)
    speaker_lookup = trainset.speaker_ids
    
    if hparams.drop_frame_rate > 0.:
        if rank != 0: # if global_mean is not yet calculated, wait for the main process to do it
            while not os.path.exists(hparams.global_mean_npy): time.sleep(1)
        global_mean = calculate_global_mean(train_loader, hparams.global_mean_npy, hparams)
        hparams.global_mean = global_mean
    
    model = load_model(hparams)

    model.eval() # test if this is needed anymore

    learning_rate = hparams.learning_rate
    if hparams.Apex_optimizer: # apex optimizer is slightly faster with slightly more vram usage in my testing. Helps in both fp32 and fp16.
        optimizer = apexopt.FusedAdam(model.parameters(), lr=learning_rate, weight_decay=hparams.weight_decay)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=hparams.weight_decay)

    if hparams.fp16_run:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

    if hparams.distributed_run:
        model = apply_gradient_allreduce(model)

    criterion = Tacotron2Loss(hparams)

    logger = prepare_directories_and_logger(
        output_directory, log_directory, rank)

    # Load checkpoint if one exists
    best_validation_loss = 0.8 # used to see when "best_model" should be saved, default = 0.8, load_checkpoint will update to last best value.
    iteration = 0
    epoch_offset = 0
    if checkpoint_path is not None:
        if warm_start:
            model, iteration = warm_start_model(
                checkpoint_path, model, hparams.ignore_layers)
        elif warm_start_force:
            model, iteration = warm_start_force_model(
                checkpoint_path, model)
        else:
            model, optimizer, _learning_rate, iteration, best_validation_loss = load_checkpoint(
                checkpoint_path, model, optimizer)
            if hparams.use_saved_learning_rate:
                learning_rate = _learning_rate
            iteration += 1  # next iteration is iteration + 1
            epoch_offset = max(0, int(iteration / len(train_loader)))
        print('Model Loaded')
	
    ## LEARNING RATE SCHEDULER
    if True:
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        min_lr = 1e-5
        factor = 0.1**(1/5) # amount to scale the LR by on Validation Loss plateau
        scheduler = ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=20, cooldown=2, min_lr=min_lr, verbose=True)
        print("ReduceLROnPlateau used as (optional) Learning Rate Scheduler.")
    else: scheduler=False

    model.train()
    is_overflow = False
	
    validate_then_terminate = 0 # I use this for testing old models with new metrics
    if validate_then_terminate:
        val_loss = validate(model, criterion, valset, iteration,
            hparams.batch_size, n_gpus, collate_fn, logger,
            hparams.distributed_run, rank)
        raise Exception("Finished Validation")
    
    for param_group in optimizer.param_groups:
        param_group['lr'] = learning_rate
    
    rolling_loss = StreamingMovingAverage(min(int(len(train_loader)), 200))
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in tqdm(range(epoch_offset, hparams.epochs), initial=epoch_offset, total=hparams.epochs, desc="Epoch:", position=1, unit="epoch"):
        tqdm.write("Epoch:{}".format(epoch))
        
        if hparams.distributed_run: # shuffles the train_loader when doing multi-gpu training
            train_sampler.set_epoch(epoch)
        start_time = time.time()
        # start iterating through the epoch
        for i, batch in tqdm(enumerate(train_loader), desc="Iter:  ", smoothing=0, total=len(train_loader), position=0, unit="iter"):
                    # run external code every iter, allows the run to be adjusted without restarts
                    if (i==0 or iteration % param_interval == 0):
                        try:
                            with open("run_every_epoch.py") as f:
                                internal_text = str(f.read())
                                if len(internal_text) > 0:
                                    #code = compile(internal_text, "run_every_epoch.py", 'exec')
                                    ldict = {'iteration': iteration}
                                    exec(internal_text, globals(), ldict)
                                else:
                                    print("No Custom code found, continuing without changes.")
                        except Exception as ex:
                            print(f"Custom code FAILED to run!\n{ex}")
                        globals().update(ldict)
                        locals().update(ldict)
                        if show_live_params:
                            print(internal_text)
                    if not iteration % 50: # check the actual learning rate every 50 iters (because the learning_rate variable can go out-of-sync with the real LR)
                        learning_rate = optimizer.param_groups[0]['lr']
                    # Learning Rate Schedule
                    if custom_lr:
                        old_lr = learning_rate
                        if iteration < warmup_start:
                            learning_rate = warmup_start_lr
                        elif iteration < warmup_end:
                            learning_rate = (iteration-warmup_start)*((A_+C_)-warmup_start_lr)/(warmup_end-warmup_start) + warmup_start_lr # learning rate increases from warmup_start_lr to A_ linearly over (warmup_end-warmup_start) iterations.
                        else:
                            if iteration < decay_start:
                                learning_rate = A_ + C_
                            else:
                                iteration_adjusted = iteration - decay_start
                                learning_rate = (A_*(e**(-iteration_adjusted/B_))) + C_
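                            # i.e. after decay_start the LR decays exponentially from A_ + C_
                            # towards the floor C_ with a time constant of B_ iterations.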
                        assert learning_rate > -1e-8, "Negative Learning Rate."
                        if old_lr != learning_rate:
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = learning_rate
                    else:
                        scheduler.patience = scheduler_patience
                        scheduler.cooldown = scheduler_cooldown
                        if override_scheduler_last_lr:
                            scheduler._last_lr = override_scheduler_last_lr
                            print("Scheduler last_lr overridden. scheduler._last_lr =", scheduler._last_lr)
                        if override_scheduler_best:
                            scheduler.best = override_scheduler_best
                            print("Scheduler best metric overridden. scheduler.best =", override_scheduler_best)
Example #20
0
    def run(self, checkpoint=None):

        opt = self.opt

        if checkpoint is not None:
            raise NotImplementedError

            # TODO: have loading checkpoints for each process

            self.model.load_state_dict(checkpoint['model'])
            prec_opt = checkpoint['opt'] if 'opt' in checkpoint else None

            opt.reset_optim = True

            if not opt.reset_optim:

                if self.is_main():
                    print("* Loading optimizer states ... ")
                self.optim.load_state_dict(checkpoint['optim'])
                if prec_opt is not None and hasattr(prec_opt, "fp16_mixed"):
                    # Only load amp information if the mode is the same
                    # Maybe its better to change between optimization mode?
                    if opt.fp16_mixed == prec_opt.fp16_mixed and opt.fp16 == prec_opt.fp16:
                        if 'amp' in checkpoint:
                            try:
                                amp.load_state_dict(checkpoint['amp'])
                            except Exception:
                                # loading the amp state can fail
                                pass

                # Only load the progress when we use the same optimizer
                if 'itr' in checkpoint:
                    itr_progress = checkpoint['itr']
                else:
                    itr_progress = None

                resume = True
                start_epoch = checkpoint[
                    'epoch'] if 'epoch' in checkpoint else 1
                if start_epoch is None:
                    start_epoch = 1
            else:
                itr_progress = None
                resume = False
                start_epoch = 1

            del checkpoint['model']
            optim_state_dict = checkpoint['optim']
            # del checkpoint['optim']
            del checkpoint
        else:
            itr_progress = None

            resume = False
            start_epoch = 1

        if opt.load_encoder_from:
            self.load_encoder_weight(opt.load_encoder_from)
        #
        if opt.load_decoder_from:
            self.load_decoder_weight(opt.load_decoder_from)

        # if we are on a GPU: warm up the memory allocator
        if self.cuda:
            self.warm_up()

        valid_loss = self.eval(self.valid_data)
        valid_ppl = math.exp(min(valid_loss, 100))

        if self.is_main():
            print('[INFO] Validation perplexity: %g' % valid_ppl, flush=True)

        self.start_time = time.time()

        for epoch in range(start_epoch, start_epoch + opt.epochs):
            self.print('')

            #  (1) train for one epoch on the training set
            train_loss = self.train_epoch(epoch,
                                          resume=resume,
                                          itr_progress=itr_progress)
            train_ppl = math.exp(min(train_loss, 100))
            self.print('[INFO] Train perplexity: %g' % train_ppl)

            #  (2) evaluate on the validation set
            valid_loss = self.eval(self.valid_data)
            valid_ppl = math.exp(min(valid_loss, 100))

            if self.is_main():
                print('[INFO] Validation perplexity: %g' % valid_ppl)
                self.save(epoch, valid_ppl)

            itr_progress = None
            resume = False
Example #21
0
def main():
    args = parse_args()
    utils.gpu_affinity.set_affinity(args.local_rank)

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    args.work_dir = utils.exp_utils.build_work_dir_name(args.work_dir,
                                                        args.dataset,
                                                        args.append_dataset,
                                                        args.append_time,
                                                        )

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['train.py', 'mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'train_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = args.txtlog_file
    dllog_file = args.dllog_file
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)

    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
                                  filename=log_file,
                                  )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.local_batch_size is not None:
        world_size = utils.distributed.get_world_size()
        args.batch_size = world_size * args.local_batch_size
        logging.info(f'--local_batch_size was set, adjusting global batch size'
                     f' to {args.batch_size} (local_batch_size * world_size)')

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    logging.info(f'world size: {utils.distributed.get_world_size()}')

    if not args.no_env:
        log_env_info()

    register_ignoring_timeout_handler()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens

    if args.mem_len == 0:
        eval_mem_len = 0
    else:
        eval_mem_len = args.mem_len + args.tgt_len - args.eval_tgt_len

    tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
                                  device=device, ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  mem_len=eval_mem_len, ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  mem_len=eval_mem_len, ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)
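    # The cutoffs split the vocabulary into frequency buckets for the adaptive
    # softmax/embedding; tie_projs[i] marks whether cluster i ties its projection
    # between the input embedding and the output layer (the head entry stays False).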

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
        }

    model = MemTransformerLM(**model_config)

    model.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(functools.partial(weights_init, args=args))

    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])

    # optimizer
    if args.optim.lower() == 'sgd':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(), lr=args.lr,
                                  momentum=args.mom)
            optimizer_sparse = None
    elif args.optim.lower() == 'adam':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params, lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(), lr=args.lr,
                                   weight_decay=args.weight_decay)
            optimizer_sparse = None
    elif args.optim.lower() == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        optimizer_sparse = None
    elif args.optim.lower() == 'lamb':
        optimizer = lamb.Lamb(model.parameters(), lr=args.lr,
                              weight_decay=args.weight_decay)
        optimizer_sparse = None
    elif args.optim.lower() == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
        optimizer_sparse = None

    model = model.to(device)

    scaler = None
    if args.fp16:
        if args.amp == 'pytorch':
            scaler = torch.cuda.amp.GradScaler()
        elif args.amp == 'apex':
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level=args.apex_amp_opt_level,
                )

    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        para_model = DistributedDataParallel(model,
                                             device_ids=[args.local_rank],
                                             output_device=args.local_rank,
                                             broadcast_buffers=False,
                                             find_unused_parameters=True,
                                             )
    elif args.multi_gpu == 'dp':
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                              model, dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model

    # scheduler
    if args.scheduler == 'cosine':
        if args.max_step_scheduler:
            max_step = args.max_step_scheduler
        else:
            max_step = args.max_step

        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, max_step - args.warmup_step, eta_min=args.eta_min)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse, max_step - args.warmup_step,
                eta_min=args.eta_min)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            else:
                return 1. / (step ** 0.5) if step > args.warmup_step \
                    else step / (args.warmup_step ** 1.5)
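        # For an assumed warmup_step of 4000 the multiplier ramps linearly to
        # 4000 / 4000**1.5 = 1/sqrt(4000) ≈ 0.0158 at the boundary and then decays
        # as 1/sqrt(step), e.g. ≈ 0.0079 at step 16000, so the schedule is continuous.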
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.LambdaLR(
                optimizer_sparse,
                lr_lambda=lr_lambda
                )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=args.decay_rate, patience=args.patience,
            min_lr=args.lr_min,
            )
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse, factor=args.decay_rate, patience=args.patience,
                min_lr=args.lr_min,
                )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'constant':
        pass

    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info('    - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    start_epoch = 1
    last_batch = 0
    last_iter = 0
    best_val_loss = None

    if args.restart:
        try:
            checkpoint = load_checkpoint(args.restart)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            scheduler.load_state_dict(checkpoint['scheduler_state'])
            if args.fp16:
                if args.amp == 'pytorch':
                    scaler.load_state_dict(checkpoint['amp_state'])
                elif args.amp == 'apex':
                    amp.load_state_dict(checkpoint['amp_state'])
            train_step = checkpoint['train_step']
            start_epoch = checkpoint['epoch']
            last_batch = checkpoint['batch']
            last_iter = checkpoint['last_iter']
            best_val_loss = checkpoint['best_val_loss']

            if train_step >= args.max_step:
                logging.info(f'Loaded checkpoint after {train_step} steps, but '
                             f'this run was scheduled for a total of '
                             f'{args.max_step} steps, exiting')
                sys.exit(1)

            model.apply(functools.partial(update_dropout, args=args))
            model.apply(functools.partial(update_dropatt, args=args))
        except FileNotFoundError:
            logging.info(f'Could not load checkpoint from {args.restart}, '
                         f'starting training from random init')

    meters = {}
    warmup = args.mem_len // args.tgt_len + 2
    meters['train_throughput'] = AverageMeter(warmup=warmup)
    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.
    start_time = time.time()
    with TimeoutHandler() as timeout_handler:
        try:
            for epoch in itertools.count(start=start_epoch):
                if args.roll:
                    tr_iter.roll(seed=args.seed + epoch)
                train_step, best_val_loss = train(
                    tr_iter, va_iter, model, para_model, model_config,
                    optimizer, optimizer_sparse, scheduler, scheduler_sparse,
                    scaler, vocab, epoch, last_batch, last_iter, train_step,
                    best_val_loss, meters, timeout_handler, device, args
                    )

                last_batch = 0
                last_iter = 0

                if train_step == args.max_step:
                    logging.info('-' * 100)
                    logging.info('End of training')
                    break
        except KeyboardInterrupt:
            logging.info('-' * 100)
            logging.info('Exiting from training early')
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    summary = {}
    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    if not args.debug and not args.no_eval and os.path.exists(test_path):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path)
        model.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        test_start_time = time.time()
        test_loss = evaluate(te_iter, model, args)
        test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
        test_elapsed = time.time() - test_start_time

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'.format(
                test_elapsed, test_loss, test_loss / math.log(2)))
        else:
            logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'.format(
                test_elapsed, test_loss, math.exp(test_loss)))
        logging.info('=' * 100)

        summary.update({
            'test_elapsed': test_elapsed,
            'test_loss': test_loss,
            })

        if args.dataset in ['enwik8', 'text8']:
            summary['test_bits_per_character'] = test_loss / math.log(2)
        else:
            summary['test_perplexity'] = math.exp(test_loss)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    if best_val_loss:
        val_perplexity = math.exp(best_val_loss)
    else:
        val_perplexity = None

    summary.update({
        'train_throughput': meters['train_throughput'].avg,
        'train_elapsed': elapsed / 60,
        'valid_loss': best_val_loss,
        'valid_perplexity': val_perplexity,
        })
    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=val_perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['train_throughput'].avg
        )
    if not passed:
        sys.exit(1)
Example #22
0
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:

        def print_pass(*args):
            pass

        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model
    print("=> creating model '{}'".format(args.arch))
    model = builder.MoCo(ResNet.__dict__[args.arch],
                         args.moco_dim,
                         args.moco_k,
                         args.moco_m,
                         args.moco_t,
                         args.mlp,
                         two_branch=True,
                         normlinear=args.normlinear)
    print(model)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.amp_opt_level != "O0":
        if amp is None:
            print(
                f"apex is not installed but amp_opt_level is set to {args.amp_opt_level}, ignoring.\n"
                "you should install apex from https://github.com/NVIDIA/apex#quick-start first"
            )
            args.amp_opt_level = "O0"
        else:
            model, optimizer = amp.initialize(model.cuda(),
                                              optimizer,
                                              opt_level=args.amp_opt_level)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
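            # e.g. an assumed global batch_size of 256 on a node with 8 GPUs gives
            # each process a local batch of 32 and ceil(workers / 8) loader workers.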
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # comment out the following line for debugging
        raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        raise NotImplementedError("Only DistributedDataParallel is supported.")

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'].state_dict())
            # optimizer = checkpoint['optimizer']
            if args.amp_opt_level != "O0" and checkpoint[
                    'args'].amp_opt_level != "O0":
                amp.load_state_dict(checkpoint['amp'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'bps-train')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if args.aug_plus:
        # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
        augmentation = [
            ToTensor3D(),
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomApply(
                [
                    transforms.ColorJitter(0.4, 0.4, 0.4,
                                           0.1)  # not strengthened
                ],
                p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply(
                [transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))],
                p=0.5),
            normalize
        ]
    else:
        # MoCo v1's aug: the same as InstDisc https://arxiv.org/abs/1805.01978
        augmentation = [
            ToTensor3D(),
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(), normalize
        ]

    train_dataset = datasets.ImageFolder(traindir,
                                         loader.ThreeCropsTransform(
                                             transforms.Compose(augmentation)),
                                         loader=sk_loader)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            state = {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer,
                'args': args,
            }
            if args.amp_opt_level != 'O0':
                state['amp'] = amp.state_dict()
            save_checkpoint(state, is_best=False, save_dir=args.save_dir, \
                filename='checkpoint_{:04d}.pth.tar'.format(epoch), epoch=epoch)
Example #23
0
def make_optimizer_and_schedule(args, model, checkpoint, params):
    """
    *Internal Function* (called directly from train_model)

    Creates an optimizer and a schedule for a given model, restoring from a
    checkpoint if it is non-null.

    Args:
        args (object) : an arguments object, see
            :meth:`~robustness.train.train_model` for details
        model (AttackerModel) : the model to create the optimizer for
        checkpoint (dict) : a loaded checkpoint saved by this library and loaded
            with `ch.load`
        params (list|None) : a list of parameters that should be updatable, all
            other params will not update. If ``None``, update all params 

    Returns:
        An optimizer (ch.nn.optim.Optimizer) and a scheduler
            (ch.nn.optim.lr_schedulers module).
    """
    # Make optimizer
    param_list = model.parameters() if params is None else params
    optimizer = SGD(param_list,
                    args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)

    if args.mixed_precision:
        model.to('cuda')
        model, optimizer = amp.initialize(model, optimizer, 'O1')

    # Make schedule
    schedule = None
    if args.custom_lr_multiplier == 'cyclic':
        eps = args.epochs
        lr_func = lambda t: np.interp([t], [0, eps * 4 // 15, eps], [0, 1, 0])[0]
        schedule = lr_scheduler.LambdaLR(optimizer, lr_func)
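        # e.g. with an assumed args.epochs of 150 the multiplier rises linearly
        # from 0 at epoch 0 to 1 at epoch 40 (150 * 4 // 15) and back to 0 at
        # epoch 150, giving a triangular "cyclic" learning-rate profile.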
    elif args.custom_lr_multiplier:
        cs = args.custom_lr_multiplier
        periods = eval(cs) if type(cs) is str else cs
        if args.lr_interpolation == 'linear':
            lr_func = lambda t: np.interp([t], *zip(*periods))[0]
        else:

            def lr_func(ep):
                for (milestone, lr) in reversed(periods):
                    if ep >= milestone: return lr
                return 1.0

        schedule = lr_scheduler.LambdaLR(optimizer, lr_func)
    elif args.step_lr:
        schedule = lr_scheduler.StepLR(optimizer,
                                       step_size=args.step_lr,
                                       gamma=args.step_lr_gamma)

    # Fast-forward the optimizer and the scheduler if resuming
    if checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        try:
            schedule.load_state_dict(checkpoint['schedule'])
        except Exception:
            steps_to_take = checkpoint['epoch']
            print('Could not load schedule (was probably LambdaLR).'
                  f' Stepping {steps_to_take} times instead...')
            for i in range(steps_to_take):
                schedule.step()

        if 'amp' in checkpoint and checkpoint['amp'] is not None:
            amp.load_state_dict(checkpoint['amp'])

        # TODO: see if there's a smarter way to do this
        # TODO: see what's up with loading fp32 weights and then MP training
        if args.mixed_precision:
            model.load_state_dict(checkpoint['model'])

    return optimizer, schedule
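A hedged usage sketch of the helper above, assuming args, model and checkpoint objects of the kind described in the docstring (variable names outside that docstring are assumptions):

# Fresh run: no checkpoint to restore from, and all parameters trainable.
optimizer, schedule = make_optimizer_and_schedule(args, model,
                                                  checkpoint=None,
                                                  params=None)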
Example #24
0
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    torch.manual_seed(args.seed + args.rank)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         drop_connect_rate=args.drop_connect,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    optimizer = create_optimizer(args, model)

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(model, args.resume)
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    resume_state = None  # clear it

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model,
                        device_ids=[args.local_rank
                                    ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        rand_erase_count=args.recount,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=
        'random',  # FIXME cleanly resolve this? data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
    )

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            logging.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema)

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

            if model_ema is not None and not args.model_ema_force_cpu:
                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric,
                    use_amp=use_amp)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
Example #25
0
def train_and_fit(args):

    if args.fp16:
        from apex import amp
    else:
        amp = None

    cuda = torch.cuda.is_available()

    train_loader, test_loader, train_len, test_len = load_dataloaders(args)
    logger.info("Loaded %d Training samples." % train_len)

    if args.model_no == 0:
        from ..model.BERT.modeling_bert import BertModel as Model
        model = args.model_size  #'bert-base-uncased'
        lower_case = True
        model_name = 'BERT'
        net = Model.from_pretrained(model, force_download=False, \
                                model_size=args.model_size,
                                task='classification' if args.task != 'fewrel' else 'fewrel',\
                                n_classes_=args.num_classes)
    elif args.model_no == 1:
        from ..model.ALBERT.modeling_albert import AlbertModel as Model
        model = args.model_size  #'albert-base-v2'
        lower_case = True
        model_name = 'ALBERT'
        net = Model.from_pretrained(model, force_download=False, \
                                model_size=args.model_size,
                                task='classification' if args.task != 'fewrel' else 'fewrel',\
                                n_classes_=args.num_classes)
    elif args.model_no == 2:  # BioBert
        from ..model.BERT.modeling_bert import BertModel, BertConfig
        model = 'bert-base-uncased'
        lower_case = False
        model_name = 'BioBERT'
        config = BertConfig.from_pretrained(
            './additional_models/biobert_v1.1_pubmed/bert_config.json')
        net = BertModel.from_pretrained(pretrained_model_name_or_path='./additional_models/biobert_v1.1_pubmed/biobert_v1.1_pubmed.bin',
                                          config=config,
                                          force_download=False, \
                                          model_size='bert-base-uncased',
                                          task='classification' if args.task != 'fewrel' else 'fewrel',\
                                          n_classes_=args.num_classes)

    tokenizer = load_pickle("%s_tokenizer.pkl" % model_name)
    net.resize_token_embeddings(len(tokenizer))
    e1_id = tokenizer.convert_tokens_to_ids('[E1]')
    e2_id = tokenizer.convert_tokens_to_ids('[E2]')
    assert e1_id != e2_id != 1

    if cuda:
        net.cuda()

    logger.info("FREEZING MOST HIDDEN LAYERS...")
    if args.model_no == 0:
        unfrozen_layers = ["classifier", "pooler", "encoder.layer.11", \
                           "classification_layer", "blanks_linear", "lm_linear", "cls"]
    elif args.model_no == 1:
        unfrozen_layers = ["classifier", "pooler", "classification_layer",\
                           "blanks_linear", "lm_linear", "cls",\
                           "albert_layer_groups.0.albert_layers.0.ffn"]
    elif args.model_no == 2:
        unfrozen_layers = ["classifier", "pooler", "encoder.layer.11", \
                           "classification_layer", "blanks_linear", "lm_linear", "cls"]

    for name, param in net.named_parameters():
        if not any([layer in name for layer in unfrozen_layers]):
            print("[FROZE]: %s" % name)
            param.requires_grad = False
        else:
            print("[FREE]: %s" % name)
            param.requires_grad = True

    if args.use_pretrained_blanks == 1:
        logger.info(
            "Loading model pre-trained on blanks at ./data/test_checkpoint_%d.pth.tar..."
            % args.model_no)
        checkpoint_path = "./data/test_checkpoint_%d.pth.tar" % args.model_no
        checkpoint = torch.load(checkpoint_path)
        model_dict = net.state_dict()
        pretrained_dict = {
            k: v
            for k, v in checkpoint['state_dict'].items()
            if k in model_dict.keys()
        }
        model_dict.update(pretrained_dict)
        net.load_state_dict(model_dict)
        del checkpoint, pretrained_dict, model_dict

    criterion = nn.CrossEntropyLoss(ignore_index=-1)
    optimizer = optim.Adam([{"params": net.parameters(), "lr": args.lr}])

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2,4,6,8,12,15,18,20,22,\
                                                                      24,26,30], gamma=0.8)

    start_epoch, best_pred, amp_checkpoint = load_state(net,
                                                        optimizer,
                                                        scheduler,
                                                        args,
                                                        load_best=False)

    if (args.fp16) and (amp is not None):
        logger.info("Using fp16...")
        net, optimizer = amp.initialize(net, optimizer, opt_level='O2')
        if amp_checkpoint is not None:
            amp.load_state_dict(amp_checkpoint)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2,4,6,8,12,15,18,20,22,\
                                                                          24,26,30], gamma=0.8)

    losses_per_epoch, accuracy_per_epoch, test_f1_per_epoch = load_results(
        args.model_no)

    logger.info("Starting training process...")
    pad_id = tokenizer.pad_token_id
    mask_id = tokenizer.mask_token_id
    update_size = len(train_loader) // 1
    for epoch in range(start_epoch, args.num_epochs):
        start_time = time.time()
        net.train()
        total_loss = 0.0
        losses_per_batch = []
        total_acc = 0.0
        accuracy_per_batch = []
        for i, data in enumerate(train_loader, 0):
            x, e1_e2_start, labels, _, _, _ = data
            attention_mask = (x != pad_id).float()
            token_type_ids = torch.zeros((x.shape[0], x.shape[1])).long()

            if cuda:
                x = x.cuda()
                labels = labels.cuda()
                attention_mask = attention_mask.cuda()
                token_type_ids = token_type_ids.cuda()

            #classification_logits, v1v2 = net(x, token_type_ids=token_type_ids, attention_mask=attention_mask, Q=None,\
            #              e1_e2_start=e1_e2_start)
            classification_logits = net(x, token_type_ids=token_type_ids, attention_mask=attention_mask, Q=None, \
                                              e1_e2_start=e1_e2_start)

            #return classification_logits, labels, net, tokenizer # for debugging now

            loss = criterion(classification_logits, labels.squeeze(1))
            loss = loss / args.gradient_acc_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if args.fp16:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), args.max_norm)
            else:
                grad_norm = clip_grad_norm_(net.parameters(), args.max_norm)

            if (i % args.gradient_acc_steps) == 0:
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item()
            total_acc += evaluate_(classification_logits, labels, \
                                   ignore_idx=-1)[0]

            if (i % update_size) == (update_size - 1):
                losses_per_batch.append(args.gradient_acc_steps * total_loss /
                                        update_size)
                accuracy_per_batch.append(total_acc / update_size)
                print(
                    '[Epoch: %d, %5d/ %d points] total loss, accuracy per batch: %.3f, %.3f'
                    % (epoch + 1, (i + 1) * args.batch_size, train_len,
                       losses_per_batch[-1], accuracy_per_batch[-1]))
                total_loss = 0.0
                total_acc = 0.0

        scheduler.step()
        results = evaluate_results(net, test_loader, pad_id, cuda)
        losses_per_epoch.append(sum(losses_per_batch) / len(losses_per_batch))
        accuracy_per_epoch.append(
            sum(accuracy_per_batch) / len(accuracy_per_batch))
        test_f1_per_epoch.append(results['f1'])
        print("Epoch finished, took %.2f seconds." %
              (time.time() - start_time))
        print("Losses at Epoch %d: %.7f" % (epoch + 1, losses_per_epoch[-1]))
        print("Train accuracy at Epoch %d: %.7f" %
              (epoch + 1, accuracy_per_epoch[-1]))
        print("Test f1 at Epoch %d: %.7f" % (epoch + 1, test_f1_per_epoch[-1]))

        if accuracy_per_epoch[-1] > best_pred:
            best_pred = accuracy_per_epoch[-1]
            torch.save({
                    'epoch': epoch + 1,\
                    'state_dict': net.state_dict(),\
                    'best_acc': accuracy_per_epoch[-1],\
                    'optimizer' : optimizer.state_dict(),\
                    'scheduler' : scheduler.state_dict(),\
                    'amp': amp.state_dict() if amp is not None else amp
                }, os.path.join("./data/" , "task_test_model_best_%d.pth.tar" % args.model_no))

        if (epoch % 1) == 0:
            save_as_pickle("task_test_losses_per_epoch_%d.pkl" % args.model_no,
                           losses_per_epoch)
            save_as_pickle(
                "task_train_accuracy_per_epoch_%d.pkl" % args.model_no,
                accuracy_per_epoch)
            save_as_pickle("task_test_f1_per_epoch_%d.pkl" % args.model_no,
                           test_f1_per_epoch)
            torch.save({
                    'epoch': epoch + 1,\
                    'state_dict': net.state_dict(),\
                    'best_acc': accuracy_per_epoch[-1],\
                    'optimizer' : optimizer.state_dict(),\
                    'scheduler' : scheduler.state_dict(),\
                    'amp': amp.state_dict() if amp is not None else amp
                }, os.path.join("./data/" , "task_test_checkpoint_%d.pth.tar" % args.model_no))

    logger.info("Finished Training!")
    fig = plt.figure(figsize=(20, 20))
    ax = fig.add_subplot(111)
    ax.scatter([e for e in range(len(losses_per_epoch))], losses_per_epoch)
    ax.tick_params(axis="both", length=2, width=1, labelsize=14)
    ax.set_xlabel("Epoch", fontsize=22)
    ax.set_ylabel("Training Loss per batch", fontsize=22)
    ax.set_title("Training Loss vs Epoch", fontsize=32)
    plt.savefig(
        os.path.join("./data/", "task_loss_vs_epoch_%d.png" % args.model_no))

    fig2 = plt.figure(figsize=(20, 20))
    ax2 = fig2.add_subplot(111)
    ax2.scatter([e for e in range(len(accuracy_per_epoch))],
                accuracy_per_epoch)
    ax2.tick_params(axis="both", length=2, width=1, labelsize=14)
    ax2.set_xlabel("Epoch", fontsize=22)
    ax2.set_ylabel("Training Accuracy", fontsize=22)
    ax2.set_title("Training Accuracy vs Epoch", fontsize=32)
    plt.savefig(
        os.path.join("./data/",
                     "task_train_accuracy_vs_epoch_%d.png" % args.model_no))

    fig3 = plt.figure(figsize=(20, 20))
    ax3 = fig3.add_subplot(111)
    ax3.scatter([e for e in range(len(test_f1_per_epoch))], test_f1_per_epoch)
    ax3.tick_params(axis="both", length=2, width=1, labelsize=14)
    ax3.set_xlabel("Epoch", fontsize=22)
    ax3.set_ylabel("Test F1 Accuracy", fontsize=22)
    ax3.set_title("Test F1 vs Epoch", fontsize=32)
    plt.savefig(
        os.path.join("./data/",
                     "task_test_f1_vs_epoch_%d.png" % args.model_no))

    return net
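The train_and_fit routine above reads everything from an argparse-style args object. A minimal sketch of that namespace follows; the field values are illustrative assumptions, not the original defaults.
from argparse import Namespace

args = Namespace(
    fp16=0,                   # 1 to enable apex mixed precision
    model_no=0,               # 0: BERT, 1: ALBERT, 2: BioBERT
    model_size='bert-base-uncased',
    task='classification',
    num_classes=19,           # hypothetical label count
    use_pretrained_blanks=0,  # 1 to load the blanks pre-trained checkpoint
    lr=7e-5,
    num_epochs=11,
    gradient_acc_steps=1,
    max_norm=1.0,
    batch_size=32,
)
# net = train_and_fit(args)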
Example #26
0
def train_model(dataloders, model, criterion, optimizer, num_epochs=25):
    since = time.time()
    use_gpu = torch.cuda.is_available()
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'amp': amp.state_dict()
    }
    torch.save(checkpoint, 'best_amp_checkpoint.pt')
    best_acc = 0.0
    dataset_sizes = {
        'train': len(dataloders['train'].dataset),
        'valid': len(dataloders['valid'].dataset)
    }

    for epoch in range(num_epochs):
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in tqdm(dataloders[phase]):
                if use_gpu:
                    inputs, labels = Variable(inputs.to(device)), Variable(
                        labels.to(device))
                else:
                    inputs, labels = Variable(inputs), Variable(labels)

                optimizer.zero_grad()

                outputs = model(inputs)
                _, preds = torch.max(outputs.data, 1)
                # return outputs, labels, preds
                loss = criterion(outputs, labels)

                if phase == 'train':
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    optimizer.step()

                running_loss += loss.data.item()
                running_corrects += torch.sum(
                    preds == labels.data).cpu().item()
                # print(running_corrects)

            if phase == 'train':
                train_epoch_loss = running_loss / dataset_sizes[phase]
                train_epoch_acc = running_corrects / dataset_sizes[phase]
            else:
                valid_epoch_loss = running_loss / dataset_sizes[phase]
                valid_epoch_acc = running_corrects / dataset_sizes[phase]

            if phase == 'valid' and valid_epoch_acc > best_acc:
                best_acc = valid_epoch_acc
                checkpoint = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'amp': amp.state_dict()
                }
                torch.save(checkpoint, 'best_amp_checkpoint.pt')

        print('Epoch [{}/{}] train loss: {:.8f} acc: {:.4f} '
              'valid loss: {:.8f} acc: {:.4f}'.format(epoch, num_epochs - 1,
                                                      train_epoch_loss,
                                                      train_epoch_acc,
                                                      valid_epoch_loss,
                                                      valid_epoch_acc))

    print('Best val Acc: {:4f}'.format(best_acc))

    checkpoint = torch.load('best_amp_checkpoint.pt')

    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    amp.load_state_dict(checkpoint['amp'])

    del inputs, labels

    return model, optimizer, amp
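train_model above relies on several module-level names that are not defined inside the function: device, opt_level, and an apex amp object that has already been initialized together with the model and optimizer. A sketch of that assumed setup (the backbone choice and opt_level are placeholders, not from the original):
import torch
import torch.nn as nn
from apex import amp                    # requires NVIDIA apex and a CUDA device
from torchvision import models

device = torch.device('cuda')           # apex AMP expects the model on CUDA
opt_level = 'O1'                        # hypothetical AMP optimization level

model = models.resnet18(num_classes=2).to(device)   # placeholder backbone
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
criterion = nn.CrossEntropyLoss()
# model, optimizer, amp = train_model(dataloders, model, criterion, optimizer)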
Example #27
0
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print(f"Use GPU: {args.gpu} for training!")

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # create model
    if 'alexnet' in args.arch:  # NEW
        if args.pretrained:
            model = AlexNet.from_pretrained(args.arch, args.num_classes)
            print(f"=> using pre-trained model '{args.arch}'")
        else:
            print(f"=> creating model '{args.arch}'")
            model = AlexNet.from_name(args.arch)
    else:
        warnings.warn("Please use --arch alexnet.")

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available
        # GPUs
        if args.arch.startswith('alexnet'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print(f"=> loading checkpoint '{args.resume}'")
            checkpoint = torch.load(args.resume)
            compress_model(checkpoint, filename=args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            amp.load_state_dict(checkpoint['amp'])
            print(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})")
        else:
            print(f"=> no checkpoint found at '{args.resume}'")

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transform=transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    if 'alexnet' in args.arch:
        image_size = AlexNet.get_image_size(args.arch)
        val_transforms = transforms.Compose([
            transforms.Resize(image_size, interpolation=PIL.Image.BICUBIC),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            normalize,
        ])
        print('Using image size', image_size)
    else:
        val_transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        print('Using image size', 224)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, val_transforms),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        top1 = validate(val_loader, model, criterion, args)
        with open('res.txt', 'w') as f:
            print(f"Acc@1: {top1}", file=f)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                    and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
                'amp': amp.state_dict(),
            }, is_best)
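main_worker follows the upstream PyTorch ImageNet example, where one process is spawned per GPU. A sketch of how a worker with this signature is typically launched, assuming that same pattern and a parsed argparse namespace args:
import torch
import torch.multiprocessing as mp

def launch(args):
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # one process per GPU on this node; the global world size grows accordingly
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # single process, possibly pinned to args.gpu
        main_worker(args.gpu, ngpus_per_node, args)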
Example #28
0
File: train.py Project: vsocrates/luke
def run_pretraining(args):
    if args.parallel and args.local_rank == -1:
        run_parallel_pretraining(args)
        return

    if args.local_rank == -1:
        if args.cpu:
            print("CPU!!!")
            device = torch.device("cpu")
        else:
            device = torch.device("cuda")
        num_workers = 1
        worker_index = 0
    else:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        device = torch.device("cuda", args.local_rank)
        num_workers = torch.distributed.get_world_size()
        worker_index = torch.distributed.get_rank()

    if args.local_rank not in (-1, 0):
        logging.getLogger().setLevel(logging.WARN)

    logger.info(
        "Starting pretraining with the following arguments: %s", json.dumps(vars(args), indent=2, sort_keys=True)
    )

    # if args.multilingual:
    #     dataset_dir_list = args.dataset_dir.split(",")
    #     dataset_list = [MedMentionsPretrainingDataset(d) for d in dataset_dir_list]
    # else:
    dataset_list = [MedMentionsPretrainingDataset(args.dataset_dir)]

    bert_config = AutoConfig.from_pretrained(args.bert_model_name)

    dataset_size = sum([len(d) for d in dataset_list])
    num_train_steps_per_epoch = math.ceil(dataset_size / args.batch_size)
    num_train_steps = math.ceil(dataset_size / args.batch_size * args.num_epochs)
    print("The Number of Training Steps is: ", num_train_steps)
    train_batch_size = int(args.batch_size / args.gradient_accumulation_steps / num_workers)

    entity_vocab = dataset_list[0].entity_vocab
    config = LukeConfig(
        entity_vocab_size=entity_vocab.size,
        bert_model_name=args.bert_model_name,
        entity_emb_size=args.entity_emb_size,
        **bert_config.to_dict(),
    )
    model = LukePretrainingModel(config)

    global_step = args.global_step

    batch_generator_args = dict(
        batch_size=train_batch_size,
        masked_lm_prob=args.masked_lm_prob,
        masked_entity_prob=args.masked_entity_prob,
        whole_word_masking=args.whole_word_masking,
        unmasked_word_prob=args.unmasked_word_prob,
        random_word_prob=args.random_word_prob,
        unmasked_entity_prob=args.unmasked_entity_prob,
        random_entity_prob=args.random_entity_prob,
        mask_words_in_entity_span=args.mask_words_in_entity_span,
        num_workers=num_workers,
        worker_index=worker_index,
        skip=global_step * args.batch_size,
    )

    # if args.multilingual:
    #     data_size_list = [len(d) for d in dataset_list]
    #     batch_generator = MultilingualBatchGenerator(
    #         dataset_dir_list, data_size_list, args.sampling_smoothing, **batch_generator_args,
    #     )

    # else:
    batch_generator = LukePretrainingBatchGenerator(args.dataset_dir, **batch_generator_args)

    logger.info("Model configuration: %s", config)

    if args.fix_bert_weights:
        for param in model.parameters():
            param.requires_grad = False
        for param in model.entity_embeddings.parameters():
            param.requires_grad = True
        for param in model.entity_predictions.parameters():
            param.requires_grad = True

    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {
            "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]

    if args.original_adam:
        optimizer = AdamW(
            optimizer_parameters,
            lr=args.learning_rate,
            betas=(args.adam_b1, args.adam_b2),
            eps=args.adam_eps,
        )
    else:
        optimizer = LukeAdamW(
            optimizer_parameters,
            lr=args.learning_rate,
            betas=(args.adam_b1, args.adam_b2),
            eps=args.adam_eps,
            grad_avg_device=torch.device("cpu") if args.grad_avg_on_cpu else device,
        )

    if args.fp16:
        from apex import amp

        if args.fp16_opt_level == "O2":
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level=args.fp16_opt_level,
                master_weights=args.fp16_master_weights,
                min_loss_scale=args.fp16_min_loss_scale,
                max_loss_scale=args.fp16_max_loss_scale,
            )
        else:
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level=args.fp16_opt_level,
                min_loss_scale=args.fp16_min_loss_scale,
                max_loss_scale=args.fp16_max_loss_scale,
            )

    if args.model_file is None:
        bert_model = AutoModelForPreTraining.from_pretrained(args.bert_model_name)
        bert_state_dict = bert_model.state_dict()
        model.load_bert_weights(bert_state_dict)

    else:
        model_state_dict = torch.load(args.model_file, map_location="cpu")
        model.load_state_dict(model_state_dict, strict=False)

    if args.optimizer_file is not None:
        optimizer.load_state_dict(torch.load(args.optimizer_file, map_location="cpu"))

    if args.amp_file is not None:
        amp.load_state_dict(torch.load(args.amp_file, map_location="cpu"))

    if args.lr_schedule == "warmup_constant":
        scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
    elif args.lr_schedule == "warmup_linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=num_train_steps
        )
        print(f"Scheduler data: Warmup steps: {args.warmup_steps}; total training steps: {num_train_steps}")
    else:
        raise RuntimeError(f"Invalid scheduler: {args.lr_schedule}")

    if args.scheduler_file is not None:
        scheduler.load_state_dict(torch.load(args.scheduler_file, map_location="cpu"))

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True,
        )

    model.train()

    if args.local_rank == -1 or worker_index == 0:
        entity_vocab.save(os.path.join(args.output_dir, ENTITY_VOCAB_FILE))
        metadata = dict(
            model_config=config.to_dict(),
            max_seq_length=dataset_list[0].max_seq_length,
            max_entity_length=dataset_list[0].max_entity_length,
            max_mention_length=dataset_list[0].max_mention_length,
            arguments=vars(args),
        )
        with open(os.path.join(args.output_dir, "metadata.json"), "w") as metadata_file:
            json.dump(metadata, metadata_file, indent=2, sort_keys=True)

    def save_model(model, suffix):
        if args.local_rank != -1:
            model = model.module

        model_file = f"model_{suffix}.bin"
        torch.save(model.state_dict(), os.path.join(args.output_dir, model_file))
        optimizer_file = f"optimizer_{suffix}.bin"
        torch.save(optimizer.state_dict(), os.path.join(args.output_dir, optimizer_file))
        scheduler_file = f"scheduler_{suffix}.bin"
        torch.save(scheduler.state_dict(), os.path.join(args.output_dir, scheduler_file))
        metadata = dict(
            global_step=global_step, model_file=model_file, optimizer_file=optimizer_file, scheduler_file=scheduler_file
        )
        if args.fp16:
            amp_file = f"amp_{suffix}.bin"
            torch.save(amp.state_dict(), os.path.join(args.output_dir, amp_file))
            metadata["amp_file"] = amp_file
        with open(os.path.join(args.output_dir, f"metadata_{suffix}.json"), "w") as f:
            json.dump(metadata, f, indent=2, sort_keys=True)

    if args.local_rank == -1 or worker_index == 0:
        summary_writer = SummaryWriter(args.log_dir)
        pbar = tqdm(total=num_train_steps, initial=global_step)

    tr_loss = 0
    accumulation_count = 0
    results = []
    prev_error = False
    prev_step_time = time.time()
    prev_save_time = time.time()

    for batch in batch_generator.generate_batches():
        try:
            batch = {k: torch.from_numpy(v).to(device) for k, v in batch.items()}
            result = model(**batch)
            loss = result["loss"]
            result = {k: v.to("cpu").detach().numpy() for k, v in result.items()}

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            def maybe_no_sync():
                if (
                    hasattr(model, "no_sync")
                    and num_workers > 1
                    and accumulation_count + 1 != args.gradient_accumulation_steps
                ):
                    return model.no_sync()
                else:
                    return contextlib.ExitStack()

            with maybe_no_sync():
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

        except RuntimeError:
            if prev_error:
                logger.exception("Consecutive errors have been observed. Exiting...")
                raise
            logger.exception("An unexpected error has occurred. Skipping a batch...")
            prev_error = True
            loss = None
            torch.cuda.empty_cache()
            continue

        accumulation_count += 1
        prev_error = False
        tr_loss += loss.item()
        loss = None
        results.append(result)

        if accumulation_count == args.gradient_accumulation_steps:
            if args.max_grad_norm != 0.0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            accumulation_count = 0

            summary = {}
            # scheduler.get_lr() is kept for compatibility with older PyTorch
            # releases; newer versions expose the same value via get_last_lr().
            summary["learning_rate"] = max(scheduler.get_lr())
            summary["loss"] = tr_loss
            tr_loss = 0

            current_time = time.time()
            summary["batch_run_time"] = current_time - prev_step_time
            prev_step_time = current_time

            for name in ("masked_lm", "masked_entity"):
                try:
                    summary[name + "_loss"] = np.concatenate([r[name + "_loss"].flatten() for r in results]).mean()
                    correct = np.concatenate([r[name + "_correct"].flatten() for r in results]).sum()
                    total = np.concatenate([r[name + "_total"].flatten() for r in results]).sum()
                    if total > 0:
                        summary[name + "_acc"] = correct / total
                except KeyError:
                    continue

            results = []

            if args.local_rank == -1 or worker_index == 0:
                for (name, value) in summary.items():
                    summary_writer.add_scalar(name, value, global_step)
                desc = (
                    f"epoch: {int(global_step / num_train_steps_per_epoch)} "
                    f'loss: {summary["loss"]:.4f} '
                    f'time: {datetime.datetime.now().strftime("%H:%M:%S")}'
                )
                pbar.set_description(desc)
                pbar.update()

            global_step += 1

            if args.local_rank == -1 or worker_index == 0:
                if global_step == num_train_steps:
                    # save the final model
                    save_model(model, f"epoch{args.num_epochs}")
                    time.sleep(60)
                elif global_step % num_train_steps_per_epoch == 0:
                    # save the model at each epoch
                    epoch = int(global_step / num_train_steps_per_epoch)
                    save_model(model, f"epoch{epoch}")
                if args.save_interval_sec and time.time() - prev_save_time > args.save_interval_sec:
                    save_model(model, f"step{global_step:07}")
                    prev_save_time = time.time()
                if args.save_interval_steps and global_step % args.save_interval_steps == 0:
                    save_model(model, f"step{global_step}")

            if global_step == num_train_steps:
                break

    if args.local_rank == -1 or worker_index == 0:
        summary_writer.close()
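The maybe_no_sync helper inside the loop above skips DistributedDataParallel gradient all-reduce on every micro-step except the one that finishes an accumulation window. The same decision, isolated as a stand-alone sketch with hypothetical names:
import contextlib

def accumulation_context(ddp_model, accumulation_count, gradient_accumulation_steps, num_workers):
    # Synchronize gradients only on the micro-step that will call optimizer.step();
    # earlier micro-steps in the window run backward() under no_sync().
    finishing_step = accumulation_count + 1 == gradient_accumulation_steps
    if hasattr(ddp_model, "no_sync") and num_workers > 1 and not finishing_step:
        return ddp_model.no_sync()
    return contextlib.ExitStack()   # no-op context manager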
Example #29
0
from typing import Tuple

def load_checkpoint(
    path: str, device: torch.device
) -> Tuple[TEDD1104, str, torch.optim.Optimizer, float, int, bool, str]:
    """
    Restore checkpoint

    Input:
    -path: path of the checkpoint to restore

    Output:
     - model: restored TEDD1104 model
     - optimizer_name: Name of the optimizer used for training: SGD or Adam
     - optimizer: Optimizer used for training
     - acc_dev: Accuracy of the model in the development set
     - epoch: Num of epoch used to train the model
     - fp16: true if the model uses fp16 else false
     - amp_opt_level: If the model uses FP16, the AMP opt_level

    Output:
    """

    checkpoint = torch.load(path)
    dict_hyperparams = checkpoint["hyper_params"]
    model_weights = checkpoint["model"]
    optimizer_name = checkpoint["optimizer_name"]
    optimizer_state = checkpoint["optimizer"]
    acc_dev = checkpoint["acc_dev"]
    epoch = checkpoint["epoch"]
    amp_state = checkpoint["amp"]
    opt_level = checkpoint["opt_level"]
    fp16 = dict_hyperparams["fp16"]

    model: TEDD1104 = TEDD1104(
        resnet=dict_hyperparams["resnet"],
        pretrained_resnet=dict_hyperparams["pretrained_resnet"],
        sequence_size=dict_hyperparams["sequence_size"],
        embedded_size=dict_hyperparams["embedded_size"],
        hidden_size=dict_hyperparams["hidden_size"],
        num_layers_lstm=dict_hyperparams["num_layers_lstm"],
        bidirectional_lstm=dict_hyperparams["bidirectional_lstm"],
        layers_out=dict_hyperparams["layers_out"],
        dropout_cnn=dict_hyperparams["dropout_cnn"],
        dropout_cnn_out=dict_hyperparams["dropout_cnn_out"],
        dropout_lstm=dict_hyperparams["dropout_lstm"],
        dropout_lstm_out=dict_hyperparams["dropout_lstm_out"],
    ).to(device=device)

    if optimizer_name == "SGD":
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif optimizer_name == "Adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    else:
        raise ValueError(
            f"The optimizer you are trying to load is unknown: "
            f"Optimizer name {optimizer_name}. Available optimizers: SGD, Adam"
        )

    if fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "The model you are trying to load uses FP16 training. "
                "Please install apex from https://www.github.com/nvidia/apex")

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=opt_level)
        amp.load_state_dict(amp_state)

    model.load_state_dict(model_weights)
    optimizer.load_state_dict(optimizer_state)

    return (
        model,
        optimizer_name,
        optimizer,
        acc_dev,
        epoch,
        fp16,
        opt_level,
    )
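A minimal usage sketch for load_checkpoint; the checkpoint path is a placeholder, and TEDD1104 plus this function are assumed to be importable from the same module:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, optimizer_name, optimizer, acc_dev, epoch, fp16, opt_level = load_checkpoint(
    "checkpoints/tedd1104_checkpoint.pt", device   # hypothetical path
)
model.eval()
print(f"Restored {optimizer_name} run: epoch {epoch}, dev acc {acc_dev:.4f}, fp16={fp16}")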
Example #30
0
def train_and_fit(args):

    if args.fp16:
        from apex import amp
    else:
        amp = None

    cuda = torch.cuda.is_available()

    train_loader = load_dataloaders(args)
    train_len = len(train_loader)
    logger.info("Loaded %d pre-training samples." % train_len)

    #net = BertModel.from_pretrained('bert-base-uncased', force_download=False)
    net = BertModel.from_pretrained(args.pretrain_model, force_download=False)
    tokenizer = load_pickle("BERT_tokenizer.pkl")
    net.resize_token_embeddings(len(tokenizer))
    if cuda:
        net.cuda()

    logger.info("FREEZING MOST HIDDEN LAYERS...")
    unfrozen_layers = [
        "classifier", "pooler", "encoder.layer.11", "blanks_linear",
        "lm_linear", "cls"
    ]
    for name, param in net.named_parameters():
        if not any([layer in name for layer in unfrozen_layers]):
            print("[FROZE]: %s" % name)
            param.requires_grad = False
        else:
            print("[FREE]: %s" % name)
            param.requires_grad = True

    criterion = Two_Headed_Loss(lm_ignore_idx=tokenizer.pad_token_id)
    optimizer = optim.Adam([{"params": net.parameters(), "lr": args.lr}])

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2,4,6,8,12,15,18,20,22,\
                                                                      24,26,30], gamma=0.8)

    start_epoch, best_pred, amp_checkpoint = load_state(net,
                                                        optimizer,
                                                        scheduler,
                                                        args,
                                                        load_best=False)

    if (args.fp16) and (amp is not None):
        logger.info("Using fp16...")
        net, optimizer = amp.initialize(net, optimizer, opt_level='O2')
        if amp_checkpoint is not None:
            amp.load_state_dict(amp_checkpoint)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2,4,6,8,12,15,18,20,22,\
                                                                          24,26,30], gamma=0.8)

    losses_per_epoch, accuracy_per_epoch = load_results(args.model_no)

    logger.info("Starting training process...")
    pad_id = tokenizer.pad_token_id
    mask_id = tokenizer.mask_token_id
    update_size = len(train_loader) // 10
    for epoch in range(start_epoch, args.num_epochs):
        start_time = time.time()
        net.train()
        total_loss = 0.0
        losses_per_batch = []
        total_acc = 0.0
        lm_accuracy_per_batch = []
        for i, data in enumerate(train_loader, 0):
            x, masked_for_pred, e1_e2_start, Q, blank_labels, _, _, _, _, _ = data
            masked_for_pred = masked_for_pred[(masked_for_pred != pad_id)]
            attention_mask = (x != pad_id).float()
            token_type_ids = torch.zeros((x.shape[0], x.shape[1])).long()

            if cuda:
                x = x.cuda()
                masked_for_pred = masked_for_pred.cuda()
                Q = Q.cuda().float()
                blank_labels = blank_labels.cuda().float()
                attention_mask = attention_mask.cuda()
                token_type_ids = token_type_ids.cuda()

            x = x.long()
            blanks_logits, lm_logits = net(x, token_type_ids=token_type_ids, attention_mask=attention_mask, Q=Q,\
                          e1_e2_start=e1_e2_start)
            lm_logits = lm_logits[(x == mask_id)]

            #return lm_logits, blanks_logits, masked_for_pred, blank_labels, tokenizer # for debugging now

            loss = criterion(lm_logits, blanks_logits, masked_for_pred,
                             blank_labels)
            loss = loss / args.gradient_acc_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if args.fp16:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), args.max_norm)
            else:
                grad_norm = clip_grad_norm_(net.parameters(), args.max_norm)

            if (i % args.gradient_acc_steps) == 0:
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item()
            total_acc += evaluate_(lm_logits, blanks_logits, masked_for_pred, blank_labels, \
                                   tokenizer, print_=False)[0]

            if (i % update_size) == (update_size - 1):
                losses_per_batch.append(args.gradient_acc_steps * total_loss /
                                        update_size)
                lm_accuracy_per_batch.append(total_acc / update_size)
                print(
                    '[Epoch: %d, %5d/ %d points] total loss, lm accuracy per batch: %.3f, %.3f'
                    % (epoch + 1, (i + 1), train_len, losses_per_batch[-1],
                       lm_accuracy_per_batch[-1]))
                total_loss = 0.0
                total_acc = 0.0

        scheduler.step()
        losses_per_epoch.append(sum(losses_per_batch) / len(losses_per_batch))
        accuracy_per_epoch.append(
            sum(lm_accuracy_per_batch) / len(lm_accuracy_per_batch))
        print("Epoch finished, took %.2f seconds." %
              (time.time() - start_time))
        print("Losses at Epoch %d: %.7f" % (epoch + 1, losses_per_epoch[-1]))
        print("Accuracy at Epoch %d: %.7f" %
              (epoch + 1, accuracy_per_epoch[-1]))

        if accuracy_per_epoch[-1] > best_pred:
            best_pred = accuracy_per_epoch[-1]
            torch.save({
                    'epoch': epoch + 1,\
                    'state_dict': net.state_dict(),\
                    'best_acc': accuracy_per_epoch[-1],\
                    'optimizer' : optimizer.state_dict(),\
                    'scheduler' : scheduler.state_dict(),\
                    'amp': amp.state_dict() if amp is not None else amp
                }, os.path.join("./data/" , "test_model_best_%d.pth.tar" % args.model_no))

        if (epoch % 1) == 0:
            save_as_pickle("test_losses_per_epoch_%d.pkl" % args.model_no,
                           losses_per_epoch)
            save_as_pickle("test_accuracy_per_epoch_%d.pkl" % args.model_no,
                           accuracy_per_epoch)
            torch.save({
                    'epoch': epoch + 1,\
                    'state_dict': net.state_dict(),\
                    'best_acc': accuracy_per_epoch[-1],\
                    'optimizer' : optimizer.state_dict(),\
                    'scheduler' : scheduler.state_dict(),\
                    'amp': amp.state_dict() if amp is not None else amp
                }, os.path.join("./data/" , "test_checkpoint_%d.pth.tar" % args.model_no))

    logger.info("Finished Training!")
    fig = plt.figure(figsize=(20, 20))
    ax = fig.add_subplot(111)
    ax.scatter([e for e in range(len(losses_per_epoch))], losses_per_epoch)
    ax.tick_params(axis="both", length=2, width=1, labelsize=14)
    ax.set_xlabel("Epoch", fontsize=22)
    ax.set_ylabel("Training Loss per batch", fontsize=22)
    ax.set_title("Training Loss vs Epoch", fontsize=32)
    plt.savefig(os.path.join("./data/",
                             "loss_vs_epoch_%d.png" % args.model_no))

    fig2 = plt.figure(figsize=(20, 20))
    ax2 = fig2.add_subplot(111)
    ax2.scatter([e for e in range(len(accuracy_per_epoch))],
                accuracy_per_epoch)
    ax2.tick_params(axis="both", length=2, width=1, labelsize=14)
    ax2.set_xlabel("Epoch", fontsize=22)
    ax2.set_ylabel("Test Masked LM Accuracy", fontsize=22)
    ax2.set_title("Test Masked LM Accuracy vs Epoch", fontsize=32)
    plt.savefig(
        os.path.join("./data/", "accuracy_vs_epoch_%d.png" % args.model_no))

    return net
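One detail of the loop above that is easy to miss: the criterion output is divided by args.gradient_acc_steps before backward(), and the logging code multiplies the running total back by the same factor, so the value appended to losses_per_batch is the mean of the raw per-batch losses. A tiny worked check with assumed numbers:
gradient_acc_steps = 4
update_size = 2
raw_losses = [0.8, 1.2]                                         # raw criterion outputs
total_loss = sum(l / gradient_acc_steps for l in raw_losses)    # what the loop accumulates
logged = gradient_acc_steps * total_loss / update_size          # value appended to losses_per_batch
assert abs(logged - sum(raw_losses) / len(raw_losses)) < 1e-9   # equals the mean raw loss
print(logged)   # 1.0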