Example 1
def test_net(args, config, ckpt_path=None, save_path=None, save_name=None):
    if save_path is None:
        logger, test_output_path = create_logger(config.OUTPUT_PATH,
                                                 args.cfg,
                                                 config.DATASET.IMAGE_SET,
                                                 split='test')
        save_path = test_output_path
    if ckpt_path is None:
        _, train_output_path = create_logger(config.OUTPUT_PATH,
                                             args.cfg,
                                             config.DATASET.IMAGE_SET,
                                             split='train')
        model_prefix = os.path.join(train_output_path, config.MODEL_PREFIX)
        ckpt_path = '{}-best.model'.format(model_prefix)
        print('Use best checkpoint {}...'.format(ckpt_path))
    if save_name is None:
        save_name = os.path.split(ckpt_path)[-1]
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    result_csv_path = os.path.join(save_path,
                                   '{}_test_result.csv'.format(save_name))
    if args.repredict or not os.path.isfile(result_csv_path):
        print('test net...')
        pprint.pprint(args)
        pprint.pprint(config)
        device_ids = [int(d) for d in config.GPUS.split(',')]
        # os.environ['CUDA_VISIBLE_DEVICES'] = config.GPUS

        shutil.copy2(
            ckpt_path,
            os.path.join(save_path,
                         '{}_test_ckpt.model'.format(config.MODEL_PREFIX)))

        # torch.backends.cudnn.enabled = False
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False

        # get network
        model = eval(config.MODULE)(config)
        if len(device_ids) > 1:
            model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()
        else:
            model = model.cuda()
        if args.fp16:
            [model] = amp.initialize([model],
                                     opt_level='O2',
                                     keep_batchnorm_fp32=False)
        checkpoint = torch.load(ckpt_path,
                                map_location=lambda storage, loc: storage)
        smart_load_model_state_dict(model, checkpoint['state_dict'])

        # loader
        test_loader = make_dataloader(config, mode='test', distributed=False)
        test_dataset = test_loader.dataset
        test_database = test_dataset.database

        # test
        sentence_logits = []
        test_ids = []
        sentence_labels = []
        cur_id = 0
        model.eval()
        for batch in test_loader:
            batch = to_cuda(batch)
            output = model(*batch)
            sentence_logits.append(
                output['sentence_label_logits'].float().detach().cpu().numpy())
            batch_size = batch[0].shape[0]
            sentence_labels.append([
                test_database[cur_id + k]['label'] for k in range(batch_size)
            ])
            test_ids.append([
                test_database[cur_id + k]['pair_id'] for k in range(batch_size)
            ])
            cur_id += batch_size
        sentence_logits = np.concatenate(sentence_logits, axis=0)
        test_ids = np.concatenate(test_ids, axis=0)
        sentence_labels = np.concatenate(sentence_labels, axis=0)
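        # ALIGN_CAPTION_IMG: multi-class logits, take the argmax per sample;
        # otherwise a single logit is thresholded at zero for a binary prediction.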
        if config.DATASET.ALIGN_CAPTION_IMG:
            sentence_prediction = np.argmax(sentence_logits,
                                            axis=1).reshape(-1)
        else:
            sentence_prediction = (sentence_logits >
                                   0.).astype(int).reshape(-1)

        # generate final result csv
        dataframe = pd.DataFrame(data=sentence_prediction,
                                 columns=["sentence_pred_label"])
        dataframe['pair_id'] = test_ids
        dataframe['sentence_labels'] = sentence_labels

        # Save predictions
        dataframe = dataframe.set_index('pair_id', drop=True)
        dataframe.to_csv(result_csv_path)
        print('result csv saved to {}.'.format(result_csv_path))
    else:
        print(
            "Cache found in {}, skip test prediction!".format(result_csv_path))
        dataframe = pd.read_csv(result_csv_path)
        sentence_prediction = np.array(dataframe["sentence_pred_label"].values)
        sentence_labels = np.array(dataframe["sentence_labels"].values)

    # Evaluate predictions
    for metric in ["overall_accuracy", "easy_accuracy", "alignment_accuracy"]:
        accuracy = compute_metrics_sentence_level(metric, sentence_prediction,
                                                  sentence_labels)
        print("{} on test set is: {}".format(metric, str(accuracy)))
Example 2
def test_net(args, config, ckpt_path=None, save_path=None, save_name=None):
    if save_path is None:
        logger, test_output_path = create_logger(config.OUTPUT_PATH, args.cfg, config.DATASET.TEST_IMAGE_SET,
                                                 split='test')
        save_path = test_output_path
    if save_name is None:
        save_name = config.MODEL_PREFIX
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    result_csv_path = os.path.join(save_path,
                                   '{}_test_result_{}.csv'.format(save_name, config.DATASET.TASK))
    if args.use_cache and os.path.isfile(result_csv_path):
        print("Cache found in {}, skip test!".format(result_csv_path))
        return result_csv_path

    print('test net...')
    pprint.pprint(args)
    pprint.pprint(config)
    device_ids = [int(d) for d in config.GPUS.split(',')]
    # os.environ['CUDA_VISIBLE_DEVICES'] = config.GPUS

    if ckpt_path is None:
        _, train_output_path = create_logger(config.OUTPUT_PATH, args.cfg, config.DATASET.TRAIN_IMAGE_SET,
                                             split='train')
        model_prefix = os.path.join(train_output_path, config.MODEL_PREFIX)
        ckpt_path = '{}-best.model'.format(model_prefix)
        print('Use best checkpoint {}...'.format(ckpt_path))

    shutil.copy2(ckpt_path, os.path.join(save_path, '{}_test_ckpt_{}.model'.format(config.MODEL_PREFIX, config.DATASET.TASK)))

    # torch.backends.cudnn.enabled = False
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    # get network
    model = eval(config.MODULE)(config)
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()
    else:
        model = model.cuda()
    if args.fp16:
        [model] = amp.initialize([model],
                                 opt_level='O2',
                                 keep_batchnorm_fp32=False)
    checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    smart_load_model_state_dict(model, checkpoint['state_dict'])

    # loader
    test_loader = make_dataloader(config, mode='test', distributed=False)
    test_dataset = test_loader.dataset
    test_database = test_dataset.database

    # test
    test_probs = []
    test_ids = []
    cur_id = 0
    model.eval()
    for nbatch, batch in zip(trange(len(test_loader)), test_loader):
    # for nbatch, batch in tqdm(enumerate(test_loader)):
        batch = to_cuda(batch)
        if config.DATASET.TASK == 'Q2A':
            output = model(*batch)
            probs = F.softmax(output['label_logits'].float(), dim=1)
            batch_size = probs.shape[0]
            test_probs.append(probs.float().detach().cpu().numpy())
            test_ids.append([test_database[cur_id + k]['annot_id'] for k in range(batch_size)])
            cur_id += batch_size
        elif config.DATASET.TASK == 'QA2R':
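            # The question tensors pack one entry per candidate answer along dim 1;
            # run the model once per answer and concatenate the four probability
            # vectors along dim 1.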
            conditioned_probs = []
            for a_id in range(4):
                q_index_in_batch = test_loader.dataset.data_names.index('question')
                q_align_mat_index_in_batch = test_loader.dataset.data_names.index('question_align_matrix')
                batch_ = [*batch]
                batch_[q_index_in_batch] = batch[q_index_in_batch][:, a_id, :, :]
                batch_[q_align_mat_index_in_batch] = batch[q_align_mat_index_in_batch][:, a_id, :, :]
                output = model(*batch_)
                probs = F.softmax(output['label_logits'].float(), dim=1)
                conditioned_probs.append(probs.float().detach().cpu().numpy())
            conditioned_probs = np.concatenate(conditioned_probs, axis=1)
            test_probs.append(conditioned_probs)
            test_ids.append([test_database[cur_id + k]['annot_id'] for k in range(conditioned_probs.shape[0])])
            cur_id += conditioned_probs.shape[0]
        else:
            raise ValueError('Not Support Task {}'.format(config.DATASET.TASK))
    test_probs = np.concatenate(test_probs, axis=0)
    test_ids = np.concatenate(test_ids, axis=0)

    result_npy_path = os.path.join(save_path, '{}_test_result_{}.npy'.format(save_name, config.DATASET.TASK))
    np.save(result_npy_path, test_probs)
    print('result npy saved to {}.'.format(result_npy_path))

    # generate final result csv
    if config.DATASET.TASK == 'Q2A':
        columns = ['answer_{}'.format(i) for i in range(4)]
    else:
        columns = ['rationale_conditioned_on_a{}_{}'.format(i, j) for i in range(4) for j in range(4)]
    dataframe = pd.DataFrame(data=test_probs, columns=columns)
    dataframe['annot_id'] = test_ids
    dataframe = dataframe.set_index('annot_id', drop=True)

    dataframe.to_csv(result_csv_path)
    print('result csv saved to {}.'.format(result_csv_path))
    return result_csv_path
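
In the QA2R branch above, the question and question_align_matrix fields carry one entry per candidate answer along dim 1, so the model is run four times and the four probability vectors are concatenated. A small sketch of that slice-per-answer pattern with a dummy scorer (the shapes and the linear scorer are illustrative stand-ins, not the repo's model):

import torch
import torch.nn.functional as F

batch_size, num_answers, seq_len, dim, num_classes = 2, 4, 5, 8, 4
questions = torch.randn(batch_size, num_answers, seq_len, dim)
scorer = torch.nn.Linear(dim, num_classes)   # stand-in for the real model

conditioned_probs = []
for a_id in range(num_answers):
    q = questions[:, a_id, :, :]             # question paired with answer a_id
    logits = scorer(q.mean(dim=1))           # [batch_size, num_classes]
    conditioned_probs.append(F.softmax(logits, dim=1))
probs = torch.cat(conditioned_probs, dim=1)  # [batch_size, num_answers * num_classes]
print(probs.shape)                           # torch.Size([2, 16])
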
Example 3
def test_net(args, config, ckpt_path=None, save_path=None, save_name=None):
    print('test net...')
    pprint.pprint(args)
    pprint.pprint(config)
    device_ids = [int(d) for d in config.GPUS.split(',')]

    obj_cats = config.OBJECT_CATEGORIES
    pred_cats = config.PREDICATE_CATEGORIES

    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if ckpt_path is None:
        _, train_output_path = create_logger(config.OUTPUT_PATH, args.cfg, 
                                             config.DATASET.TRAIN_IMAGE_SET,
                                             split='train')
        model_prefix = os.path.join(train_output_path, config.MODEL_PREFIX)
        ckpt_path = '{}-best.model'.format(model_prefix)
        print('Use best checkpoint {}...'.format(ckpt_path))
    if save_path is None:
        logger, test_output_path = create_logger(config.OUTPUT_PATH, args.cfg, 
                                                 config.DATASET.TEST_IMAGE_SET,
                                                 split='test')
        save_path = test_output_path
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    shutil.copy2(ckpt_path,
                 os.path.join(save_path, '{}_test_ckpt_{}.model'.format(
                    config.MODEL_PREFIX, config.DATASET.TASK
                    )))

    # get network
    model = eval(config.MODULE)(config)
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()
    else:
        torch.cuda.set_device(device_ids[0])
        model = model.cuda()
    checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    smart_load_model_state_dict(model, checkpoint['state_dict'])

    split = args.split
    loader = make_dataloader(config, mode=split, distributed=False)

    nb_of_correct_50 = nb_of_sample = nb_of_correct_top100 = 0
    model.eval()

    save_dir = ''
    if args.visualize_mask: # For mask visualization purpose
        save_dir = 'heatmap/vrd'
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

    for nbatch, batch in zip(trange(len(loader)), loader):
        batch = to_cuda(batch)
        output = model(*batch)

        n_correct, n_sample, n_correct_top100 = compute_recall(output, obj_cats, pred_cats, remove_bg=config.TRAIN.SAMPLE_RELS != -1, visualize_mask=args.visualize_mask, save_dir=save_dir)
        nb_of_correct_50 += n_correct
        nb_of_correct_top100 += n_correct_top100
        nb_of_sample += n_sample
        
    recall_50 = nb_of_correct_50 / nb_of_sample
    recall_100 = nb_of_correct_top100 / nb_of_sample

    return recall_50, recall_100
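
compute_recall returns per-batch counts, and the loop above micro-averages them into recall@50 and recall@100. A tiny sketch of that accumulation with made-up per-batch counts:

# (n_correct_50, n_sample, n_correct_top100) per batch; values are illustrative.
batch_counts = [(40, 100, 70), (35, 80, 60), (50, 120, 95)]

nb_of_correct_50 = nb_of_correct_top100 = nb_of_sample = 0
for n_correct, n_sample, n_correct_top100 in batch_counts:
    nb_of_correct_50 += n_correct
    nb_of_correct_top100 += n_correct_top100
    nb_of_sample += n_sample

print(nb_of_correct_50 / nb_of_sample)      # recall@50  = 125 / 300
print(nb_of_correct_top100 / nb_of_sample)  # recall@100 = 225 / 300
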
Example 4
def train_net(args, config):
    # setup logger
    logger, final_output_path = create_logger(config.OUTPUT_PATH,
                                              args.cfg,
                                              config.DATASET.TRAIN_IMAGE_SET,
                                              split='train')
    model_prefix = os.path.join(final_output_path, config.MODEL_PREFIX)
    if args.log_dir is None:
        args.log_dir = os.path.join(final_output_path, 'tensorboard_logs')

    pprint.pprint(args)
    logger.info('training args:{}\n'.format(args))
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # manually set random seed
    if config.RNG_SEED > -1:
        np.random.seed(config.RNG_SEED)
        torch.random.manual_seed(config.RNG_SEED)
        torch.cuda.manual_seed_all(config.RNG_SEED)

    # cudnn
    torch.backends.cudnn.benchmark = False
    if args.cudnn_off:
        torch.backends.cudnn.enabled = False

    if args.dist:
        model = eval(config.MODULE)(config)
        local_rank = int(os.environ.get('LOCAL_RANK') or 0)
        config.GPUS = str(local_rank)
        torch.cuda.set_device(local_rank)
        master_address = os.environ['MASTER_ADDR']
        master_port = int(os.environ['MASTER_PORT'] or 23456)
        world_size = int(os.environ['WORLD_SIZE'] or 1)
        rank = int(os.environ['RANK'] or 0)
        if args.slurm:
            distributed.init_process_group(backend='nccl')
        else:
            distributed.init_process_group(backend='nccl',
                                           init_method='tcp://{}:{}'.format(
                                               master_address, master_port),
                                           world_size=world_size,
                                           rank=rank,
                                           group_name='mtorch')
        print(
            f'native distributed, size: {world_size}, rank: {rank}, local rank: {local_rank}'
        )
        torch.cuda.set_device(local_rank)
        config.GPUS = str(local_rank)
        model = model.cuda()
        if not config.TRAIN.FP16:
            model = DDP(model,
                        device_ids=[local_rank],
                        output_device=local_rank)

        if rank == 0:
            summary_parameters(
                model.module if isinstance(
                    model, torch.nn.parallel.DistributedDataParallel) else
                model, logger)
            shutil.copy(args.cfg, final_output_path)
            shutil.copy(inspect.getfile(eval(config.MODULE)),
                        final_output_path)

        writer = None
        if args.log_dir is not None:
            tb_log_dir = os.path.join(args.log_dir, 'rank{}'.format(rank))
            if not os.path.exists(tb_log_dir):
                os.makedirs(tb_log_dir)
            writer = SummaryWriter(log_dir=tb_log_dir)

        train_loader, train_sampler = make_dataloader(config,
                                                      mode='train',
                                                      distributed=True,
                                                      num_replicas=world_size,
                                                      rank=rank,
                                                      expose_sampler=True)
        val_loader = make_dataloader(config,
                                     mode='val',
                                     distributed=True,
                                     num_replicas=world_size,
                                     rank=rank)

        batch_size = world_size * (sum(config.TRAIN.BATCH_IMAGES) if
                                   isinstance(config.TRAIN.BATCH_IMAGES, list)
                                   else config.TRAIN.BATCH_IMAGES)
        if config.TRAIN.GRAD_ACCUMULATE_STEPS > 1:
            batch_size = batch_size * config.TRAIN.GRAD_ACCUMULATE_STEPS
        base_lr = config.TRAIN.LR * batch_size
        optimizer_grouped_parameters = [{
            'params': [p for n, p in model.named_parameters() if _k in n],
            'lr':
            base_lr * _lr_mult
        } for _k, _lr_mult in config.TRAIN.LR_MULT]
        optimizer_grouped_parameters.append({
            'params': [
                p for n, p in model.named_parameters()
                if all([_k not in n for _k, _ in config.TRAIN.LR_MULT])
            ]
        })
        if config.TRAIN.OPTIMIZER == 'SGD':
            optimizer = optim.SGD(optimizer_grouped_parameters,
                                  lr=config.TRAIN.LR * batch_size,
                                  momentum=config.TRAIN.MOMENTUM,
                                  weight_decay=config.TRAIN.WD)
        elif config.TRAIN.OPTIMIZER == 'Adam':
            optimizer = optim.Adam(optimizer_grouped_parameters,
                                   lr=config.TRAIN.LR * batch_size,
                                   weight_decay=config.TRAIN.WD)
        elif config.TRAIN.OPTIMIZER == 'AdamW':
            optimizer = AdamW(optimizer_grouped_parameters,
                              lr=config.TRAIN.LR * batch_size,
                              betas=(0.9, 0.999),
                              eps=1e-6,
                              weight_decay=config.TRAIN.WD,
                              correct_bias=True)
        else:
            raise ValueError('Not support optimizer {}!'.format(
                config.TRAIN.OPTIMIZER))
        total_gpus = world_size

    else:
        #os.environ['CUDA_VISIBLE_DEVICES'] = config.GPUS
        model = eval(config.MODULE)(config)

        # import pdb; pdb.set_trace()
        if config.NETWORK.VLBERT.vlbert_frozen:
            # freeze all parameters first
            for p in model.parameters():
                p.requires_grad = False

            # unfreeze the last layer(s)
            if config.NETWORK.VLBERT.vlbert_unfrozen_layers != 0:
                for p in model.vlbert.encoder.layer[
                        -config.NETWORK.VLBERT.
                        vlbert_unfrozen_layers:].parameters():
                    p.requires_grad = True

            for p in model.final_mlp.parameters():
                p.requires_grad = True

            if config.NETWORK.USE_SPATIAL_MODEL:
                for p in model.simple_spatial_model.parameters():
                    p.requires_grad = True
                for p in model.spa_fusion_linear.parameters():
                    p.requires_grad = True
                for p in model.spa_linear.parameters():
                    p.requires_grad = True
                if config.NETWORK.SPA_ONE_MORE_LAYER:
                    for p in model.spa_linear_hidden.parameters():
                        p.requires_grad = True

            # If use enhanced image feature
            if config.NETWORK.VLBERT.ENHANCED_IMG_FEATURE:
                for p in model.vlbert.obj_feat_downsample.parameters():
                    p.requires_grad = True
                for p in model.vlbert.obj_feat_batchnorm.parameters():
                    p.requires_grad = True
                for p in model.vlbert.lan_img_conv1.parameters():
                    p.requires_grad = True
                for p in model.vlbert.lan_img_conv2.parameters():
                    p.requires_grad = True
                for p in model.vlbert.lan_img_conv3.parameters():
                    p.requires_grad = True
                for p in model.vlbert.lan_img_conv4.parameters():
                    p.requires_grad = True

        if config.NETWORK.VLBERT.vlbert_frozen_embedding_LayerNorm:
            print('freezing embedding_LayerNorm...')
            for p in model.vlbert.embedding_LayerNorm.parameters():
                p.requires_grad = False
        if config.NETWORK.VLBERT.vlbert_frozen_encoder:
            print('freezing encoder...')
            for p in model.vlbert.encoder.parameters():
                p.requires_grad = False

        summary_parameters(model, logger)
        shutil.copy(args.cfg, final_output_path)
        shutil.copy(inspect.getfile(eval(config.MODULE)), final_output_path)
        num_gpus = len(config.GPUS.split(','))
        assert num_gpus <= 1 or (not config.TRAIN.FP16), "Not support fp16 with torch.nn.DataParallel. " \
                                                         "Please use amp.parallel.DistributedDataParallel instead."
        total_gpus = num_gpus
        rank = None
        writer = SummaryWriter(
            log_dir=args.log_dir) if args.log_dir is not None else None

        # model
        if num_gpus > 1:
            model = torch.nn.DataParallel(
                model,
                device_ids=[int(d) for d in config.GPUS.split(',')]).cuda()
        else:
            torch.cuda.set_device(int(config.GPUS))
            model.cuda()

        # loader
        train_loader = make_dataloader(config,
                                       mode=config.DATASET.TRAIN_IMAGE_SET,
                                       distributed=False)
        val_loader = make_dataloader(config,
                                     mode=config.DATASET.VAL_IMAGE_SET,
                                     distributed=False)
        test_loader = make_dataloader(config,
                                      mode=config.DATASET.TEST_IMAGE_SET,
                                      distributed=False)
        train_sampler = None

        batch_size = num_gpus * (sum(config.TRAIN.BATCH_IMAGES) if isinstance(
            config.TRAIN.BATCH_IMAGES, list) else config.TRAIN.BATCH_IMAGES)
        if config.TRAIN.GRAD_ACCUMULATE_STEPS > 1:
            batch_size = batch_size * config.TRAIN.GRAD_ACCUMULATE_STEPS
        base_lr = config.TRAIN.LR * batch_size
        optimizer_grouped_parameters = [{
            'params': [p for n, p in model.named_parameters() if _k in n],
            'lr':
            base_lr * _lr_mult
        } for _k, _lr_mult in config.TRAIN.LR_MULT]
        optimizer_grouped_parameters.append({
            'params': [
                p for n, p in model.named_parameters()
                if all([_k not in n for _k, _ in config.TRAIN.LR_MULT])
            ]
        })

        if config.TRAIN.OPTIMIZER == 'SGD':
            optimizer = optim.SGD(optimizer_grouped_parameters,
                                  lr=config.TRAIN.LR * batch_size,
                                  momentum=config.TRAIN.MOMENTUM,
                                  weight_decay=config.TRAIN.WD)
        elif config.TRAIN.OPTIMIZER == 'Adam':
            optimizer = optim.Adam(optimizer_grouped_parameters,
                                   lr=config.TRAIN.LR * batch_size,
                                   weight_decay=config.TRAIN.WD)
        elif config.TRAIN.OPTIMIZER == 'AdamW':
            optimizer = AdamW(optimizer_grouped_parameters,
                              lr=config.TRAIN.LR * batch_size,
                              betas=(0.9, 0.999),
                              eps=1e-6,
                              weight_decay=config.TRAIN.WD,
                              correct_bias=True)
        else:
            raise ValueError('Not support optimizer {}!'.format(
                config.TRAIN.OPTIMIZER))

    # partial load pretrain state dict
    if config.NETWORK.PARTIAL_PRETRAIN != "":
        pretrain_state_dict = torch.load(
            config.NETWORK.PARTIAL_PRETRAIN,
            map_location=lambda storage, loc: storage)['state_dict']
        prefix_change = [
            prefix_change.split('->')
            for prefix_change in config.NETWORK.PARTIAL_PRETRAIN_PREFIX_CHANGES
        ]

        if len(prefix_change) > 0:
            pretrain_state_dict_parsed = {}
            for k, v in pretrain_state_dict.items():
                no_match = True
                for pretrain_prefix, new_prefix in prefix_change:
                    if k.startswith(pretrain_prefix):
                        k = new_prefix + k[len(pretrain_prefix):]
                        pretrain_state_dict_parsed[k] = v
                        no_match = False
                        break
                if no_match:
                    pretrain_state_dict_parsed[k] = v
            pretrain_state_dict = pretrain_state_dict_parsed
        # import pdb; pdb.set_trace()
        smart_partial_load_model_state_dict(model, pretrain_state_dict)

    # pretrained classifier
    if config.NETWORK.CLASSIFIER_PRETRAINED:  # false for now
        print(
            'Initializing classifier weight from pretrained word embeddings...'
        )

        for k, v in model.state_dict().items():
            if 'word_embeddings.weight' in k:
                word_embeddings = v.detach().clone()
                break

        answers_word_embed = []
        for answer in config.PREDICATE_CATEGORIES:
            a_tokens = train_loader.dataset.tokenizer.tokenize(answer)
            a_ids = train_loader.dataset.tokenizer.convert_tokens_to_ids(
                a_tokens)
            a_word_embed = (torch.stack(
                [word_embeddings[a_id] for a_id in a_ids], dim=0)).mean(dim=0)
            answers_word_embed.append(a_word_embed)
        answers_word_embed_tensor = torch.stack(answers_word_embed, dim=0)
        for name, module in model.named_modules():
            if name.endswith('final_mlp'):
                module[-1].weight.data = answers_word_embed_tensor.to(
                    device=module[-1].weight.data.device)

    # metrics
    train_metrics_list = [
        spasen_metrics.Accuracy(allreduce=args.dist,
                                num_replicas=world_size if args.dist else 1)
    ]
    val_metrics_list = [
        spasen_metrics.Accuracy(allreduce=args.dist,
                                num_replicas=world_size if args.dist else 1)
    ]
    for output_name, display_name in config.TRAIN.LOSS_LOGGERS:
        train_metrics_list.append(
            spasen_metrics.LossLogger(
                output_name,
                display_name=display_name,
                allreduce=args.dist,
                num_replicas=world_size if args.dist else 1))
        val_metrics_list.append(
            spasen_metrics.LossLogger(
                output_name,
                display_name=display_name,
                allreduce=args.dist,
                num_replicas=world_size if args.dist else 1))

    train_metrics = CompositeEvalMetric()
    val_metrics = CompositeEvalMetric()
    for child_metric in train_metrics_list:
        train_metrics.add(child_metric)
    for child_metric in val_metrics_list:
        val_metrics.add(child_metric)

    # epoch end callbacks
    epoch_end_callbacks = []
    if (rank is None) or (rank == 0):
        epoch_end_callbacks = [
            Checkpoint(model_prefix, config.CHECKPOINT_FREQUENT)
        ]
    validation_monitor = ValidationMonitor(
        do_validation,
        val_loader,
        val_metrics,
        host_metric_name='Acc',
        label_index_in_batch=config.DATASET.LABEL_INDEX_IN_BATCH)
    testing_monitor = ValidationMonitor(
        do_validation,
        test_loader,
        val_metrics,
        host_metric_name='Acc',
        label_index_in_batch=config.DATASET.LABEL_INDEX_IN_BATCH,
        do_test=True)

    # optimizer initial lr before
    for group in optimizer.param_groups:
        group.setdefault('initial_lr', group['lr'])

    # resume/auto-resume
    if rank is None or rank == 0:
        smart_resume(model, optimizer, validation_monitor, config,
                     model_prefix, logger)
    if args.dist:
        begin_epoch = torch.tensor(config.TRAIN.BEGIN_EPOCH).cuda()
        distributed.broadcast(begin_epoch, src=0)
        config.TRAIN.BEGIN_EPOCH = begin_epoch.item()

    # batch end callbacks
    batch_size = len(config.GPUS.split(',')) * config.TRAIN.BATCH_IMAGES
    batch_end_callbacks = [
        Speedometer(batch_size,
                    config.LOG_FREQUENT,
                    batches_per_epoch=len(train_loader),
                    epochs=config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH)
    ]

    # setup lr step and lr scheduler
    if config.TRAIN.LR_SCHEDULE == 'plateau':
        print("Warning: not support resuming on plateau lr schedule!")
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=config.TRAIN.LR_FACTOR,
            patience=1,
            verbose=True,
            threshold=1e-4,
            threshold_mode='rel',
            cooldown=2,
            min_lr=0,
            eps=1e-8)
    elif config.TRAIN.LR_SCHEDULE == 'triangle':
        lr_scheduler = WarmupLinearSchedule(
            optimizer,
            config.TRAIN.WARMUP_STEPS if config.TRAIN.WARMUP else 0,
            t_total=int(config.TRAIN.END_EPOCH * len(train_loader) /
                        config.TRAIN.GRAD_ACCUMULATE_STEPS),
            last_epoch=int(config.TRAIN.BEGIN_EPOCH * len(train_loader) /
                           config.TRAIN.GRAD_ACCUMULATE_STEPS) - 1)
    elif config.TRAIN.LR_SCHEDULE == 'step':
        lr_iters = [
            int(epoch * len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS)
            for epoch in config.TRAIN.LR_STEP
        ]
        lr_scheduler = WarmupMultiStepLR(
            optimizer,
            milestones=lr_iters,
            gamma=config.TRAIN.LR_FACTOR,
            warmup_factor=config.TRAIN.WARMUP_FACTOR,
            warmup_iters=config.TRAIN.WARMUP_STEPS
            if config.TRAIN.WARMUP else 0,
            warmup_method=config.TRAIN.WARMUP_METHOD,
            last_epoch=int(config.TRAIN.BEGIN_EPOCH * len(train_loader) /
                           config.TRAIN.GRAD_ACCUMULATE_STEPS) - 1)
    else:
        raise ValueError("Not support lr schedule: {}.".format(
            config.TRAIN.LR_SCHEDULE))

    # broadcast parameter and optimizer state from rank 0 before training start
    if args.dist:
        for v in model.state_dict().values():
            distributed.broadcast(v, src=0)
        # for v in optimizer.state_dict().values():
        #     distributed.broadcast(v, src=0)
        best_epoch = torch.tensor(validation_monitor.best_epoch).cuda()
        best_val = torch.tensor(validation_monitor.best_val).cuda()
        distributed.broadcast(best_epoch, src=0)
        distributed.broadcast(best_val, src=0)
        validation_monitor.best_epoch = best_epoch.item()
        validation_monitor.best_val = best_val.item()

    # apex: amp fp16 mixed-precision training
    if config.TRAIN.FP16:
        # model.apply(bn_fp16_half_eval)
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level='O2',
            keep_batchnorm_fp32=False,
            loss_scale=config.TRAIN.FP16_LOSS_SCALE,
            min_loss_scale=32.0)
        if args.dist:
            model = Apex_DDP(model, delay_allreduce=True)

    train(model,
          optimizer,
          lr_scheduler,
          train_loader,
          train_sampler,
          train_metrics,
          config.TRAIN.BEGIN_EPOCH,
          config.TRAIN.END_EPOCH,
          logger,
          rank=rank,
          batch_end_callbacks=batch_end_callbacks,
          epoch_end_callbacks=epoch_end_callbacks,
          writer=writer,
          validation_monitor=validation_monitor,
          fp16=config.TRAIN.FP16,
          clip_grad_norm=config.TRAIN.CLIP_GRAD_NORM,
          gradient_accumulate_steps=config.TRAIN.GRAD_ACCUMULATE_STEPS,
          testing_monitor=testing_monitor)

    return rank, model
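
Both branches above build the optimizer parameter groups from config.TRAIN.LR_MULT: parameters whose names contain a listed substring get base_lr scaled by that multiplier, and everything else falls into a default group. A self-contained sketch of the same grouping on a toy model (the module names and multiplier are illustrative, not the repo's config):

import torch
import torch.optim as optim

model = torch.nn.Sequential()
model.add_module('backbone', torch.nn.Linear(8, 8))
model.add_module('final_mlp', torch.nn.Linear(8, 2))

base_lr = 1e-4
lr_mult = [('final_mlp', 10.0)]  # (name substring, multiplier), like TRAIN.LR_MULT

groups = [{'params': [p for n, p in model.named_parameters() if k in n],
           'lr': base_lr * mult}
          for k, mult in lr_mult]
groups.append({'params': [p for n, p in model.named_parameters()
                          if all(k not in n for k, _ in lr_mult)]})

optimizer = optim.SGD(groups, lr=base_lr, momentum=0.9)
print([g['lr'] for g in optimizer.param_groups])  # [0.001, 0.0001]
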
Example 5
def test_net(args, config, ckpt_path=None, save_path=None, save_name=None):
    print('test net...')
    pprint.pprint(args)
    pprint.pprint(config)
    device_ids = [int(d) for d in config.GPUS.split(',')]
    # os.environ['CUDA_VISIBLE_DEVICES'] = config.GPUS

    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if ckpt_path is None:
        _, train_output_path = create_logger(config.OUTPUT_PATH, args.cfg, config.DATASET.TRAIN_IMAGE_SET,
                                             split='train')
        model_prefix = os.path.join(train_output_path, config.MODEL_PREFIX)
        ckpt_path = '{}-best.model'.format(model_prefix)
        print('Use best checkpoint {}...'.format(ckpt_path))
    if save_path is None:
        logger, test_output_path = create_logger(config.OUTPUT_PATH, args.cfg, config.DATASET.TEST_IMAGE_SET,
                                                 split='test')
        save_path = test_output_path
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    
    test_ckpt_path = '{}_test_ckpt_{}.model'.format(config.MODEL_PREFIX, config.DATASET.TASK)
    try:
        shutil.copy2(ckpt_path,
                    os.path.join(save_path, test_ckpt_path))
    except shutil.SameFileError:
        print(f'Test checkpoint already exists: {test_ckpt_path}')

    # get network
    model = eval(config.MODULE)(config)

    if hasattr(model, 'setup_adapter'):
        model.setup_adapter()

    # if len(device_ids) > 1:
    #     model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()
    # else:
    torch.cuda.set_device(min(device_ids))
    model = model.cuda()
    
    checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    smart_load_model_state_dict(model, checkpoint['state_dict'])

    # loader
    test_loader = make_dataloader(config, mode='test', distributed=False)
    test_dataset = test_loader.dataset
    test_database = test_dataset.database

    # test
    predicts = []
    model.eval()
    cur_id = 0
    for nbatch, batch in zip(trange(len(test_loader)), test_loader):
    # for nbatch, batch in tqdm(enumerate(test_loader)):
        bs = test_loader.batch_sampler.batch_size if test_loader.batch_sampler is not None else test_loader.batch_size
        batch = to_cuda(batch)
        outputs = model(*batch[:-1])
        if outputs['label_logits'].shape[-1] == 1:
            prob = torch.sigmoid(outputs['label_logits'][:, 0]).detach().cpu().tolist()
        else:
            prob = torch.softmax(outputs['label_logits'], dim=-1)[:, 1].detach().cpu().tolist()
        sample_ids = batch[-1].cpu().tolist()
        for pb, id in zip(prob, sample_ids):
            predicts.append({
                'id': int(id),
                'proba': float(pb),
                'label': int(pb > 0.5)
            })

    cfg_name = os.path.splitext(os.path.basename(args.cfg))[0]
    output_name = cfg_name if save_name is None else save_name
    result_json_path = os.path.join(save_path, f'{output_name}_cls_{config.DATASET.TEST_IMAGE_SET}.json')
    result_csv_path = os.path.join(save_path, f'{output_name}_cls_{config.DATASET.TEST_IMAGE_SET}.csv')
    
    with open(result_json_path, 'w') as f:
        json.dump(predicts, f)
    print('result json saved to {}.'.format(result_json_path))

    pd.DataFrame.from_dict(predicts).to_csv(result_csv_path, index=False)
    return result_json_path
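
The probability extraction above covers both a single-logit head (sigmoid of the logit) and a two-class head (softmax, keep column 1). A short standalone sketch of the same branching on dummy logits (shapes and values are illustrative):

import torch

def positive_proba(label_logits: torch.Tensor) -> list:
    # Single-logit head: sigmoid of the logit is P(label = 1).
    if label_logits.shape[-1] == 1:
        return torch.sigmoid(label_logits[:, 0]).tolist()
    # Two-class head: softmax over classes, then take class 1.
    return torch.softmax(label_logits, dim=-1)[:, 1].tolist()

print(positive_proba(torch.tensor([[0.3], [-1.0]])))            # single-logit head
print(positive_proba(torch.tensor([[0.3, 1.2], [2.0, -0.5]])))  # two-class head
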
Example 6
def test_net2018(args, config, ckpt_path=None, save_path=None, save_name=None):
    print('test net...')
    pprint.pprint(args)
    pprint.pprint(config)
    device_ids = [int(d) for d in config.GPUS.split(',')]
    # os.environ['CUDA_VISIBLE_DEVICES'] = config.GPUS

    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if ckpt_path is None:
        _, train_output_path = create_logger(config.OUTPUT_PATH,
                                             args.cfg,
                                             config.DATASET.TRAIN_IMAGE_SET,
                                             split='train')
        model_prefix = os.path.join(train_output_path, config.MODEL_PREFIX)
        ckpt_path = '{}-best.model'.format(model_prefix)
        print('Use best checkpoint {}...'.format(ckpt_path))
    if save_path is None:
        logger, test_output_path = create_logger(config.OUTPUT_PATH,
                                                 args.cfg,
                                                 config.DATASET.TEST_IMAGE_SET,
                                                 split='test')
        save_path = test_output_path
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    shutil.copy2(
        ckpt_path,
        os.path.join(
            save_path, '{}_test_ckpt_{}.model'.format(config.MODEL_PREFIX,
                                                      config.DATASET.TASK)))

    # ************
    # Step 1: Select model architecture and preload trained model
    model = eval(config.MODULE)(config)
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()
    else:
        torch.cuda.set_device(device_ids[0])
        model = model.cuda()
    checkpoint = torch.load(ckpt_path,
                            map_location=lambda storage, loc: storage)
    smart_load_model_state_dict(model, checkpoint['state_dict'])

    # ************
    # Step 2: Create dataloader to include all caption-image pairs
    test_loader = make_dataloader(config, mode='test', distributed=False)
    test_dataset = test_loader.dataset
    test_database = test_dataset.database
    vocab = test_dataset.MLT_vocab

    # ************
    # Step 3: Run all pairs through model for inference
    word_de_ids = []
    words_de = []
    words_en = []
    captions_en = []
    captions_de = []
    logit_words = []
    logits = []
    model.eval()
    cur_id = 0
    for nbatch, batch in zip(trange(len(test_loader)), test_loader):
        bs = test_loader.batch_sampler.batch_size if test_loader.batch_sampler is not None else test_loader.batch_size
        # for id in range(cur_id, min(cur_id + bs, len(test_database))):
        #     print(test_database[id])
        words_de.extend([
            test_database[id]['word_de']
            for id in range(cur_id, min(cur_id + bs, len(test_database)))
        ])
        words_en.extend([
            test_database[id]['word_en']
            for id in range(cur_id, min(cur_id + bs, len(test_database)))
        ])
        captions_en.extend([
            test_database[id]['caption_en']
            for id in range(cur_id, min(cur_id + bs, len(test_database)))
        ])
        captions_de.extend([
            test_database[id]['caption_de']
            for id in range(cur_id, min(cur_id + bs, len(test_database)))
        ])
        batch = to_cuda(batch)
        output = model(*batch)
        # FM note: output is tuple (outputs, loss)
        probs = F.softmax(output[0]['MLT_logits'].float(), dim=1)
        batch_size = probs.shape[0]
        logits.extend(probs.argmax(dim=1).detach().cpu().tolist())
        # word_de_ids.extend(output[0]['MLT_label'].detach().cpu().tolist())
        logit_words.extend([
            vocab[id]
            for id in logits[cur_id:min(cur_id + bs, len(test_database))]
        ])

        cur_id += bs

        #     output = model(*batch)
        #     probs = F.softmax(output['label_logits'].float(), dim=1)
        #     batch_size = probs.shape[0]
        #     test_probs.append(probs.float().detach().cpu().numpy())
        #     test_ids.append([test_database[cur_id + k]['annot_id'] for k in range(batch_size)])
        # logits.extend(F.sigmoid(output[0]['relationship_logits']).detach().cpu().tolist())

    # ************
    # Step 4: Store all logit results in a file for later evaluation
    result = [{
        'logit': l_id,
        'word_en': word_en,
        'word_de': word_de,
        'word_pred': logit_word,
        'caption_en': caption_en,
        'caption_de': caption_de
    } for l_id, word_en, word_de, logit_word, caption_en, caption_de in zip(
        logits, words_en, words_de, logit_words, captions_en, captions_de)]
    cfg_name = os.path.splitext(os.path.basename(args.cfg))[0]
    result_json_path = os.path.join(
        save_path,
        '{}_MLT_{}.json'.format(cfg_name if save_name is None else save_name,
                                config.DATASET.TEST_IMAGE_SET))
    with open(result_json_path, 'w') as f:
        json.dump(result, f)
    print('result json saved to {}.'.format(result_json_path))
    return result_json_path
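
The MLT head above predicts an index into test_dataset.MLT_vocab, and predicted ids are mapped back to words through that list. A minimal sketch of the same id-to-word lookup with a toy vocabulary (vocabulary and logits are illustrative):

import torch
import torch.nn.functional as F

vocab = ['Hund', 'Katze', 'Haus']        # toy stand-in for MLT_vocab
mlt_logits = torch.tensor([[2.0, 0.1, -1.0],
                           [0.2, 0.3, 1.5]])

probs = F.softmax(mlt_logits.float(), dim=1)
pred_ids = probs.argmax(dim=1).tolist()  # [0, 2]
pred_words = [vocab[i] for i in pred_ids]
print(pred_ids, pred_words)              # [0, 2] ['Hund', 'Haus']
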
Example 7
def train_net(args, config):
    # setup logger
    logger, final_output_path = create_logger(config.OUTPUT_PATH,
                                              args.cfg,
                                              config.DATASET[0].TRAIN_IMAGE_SET if isinstance(config.DATASET, list)
                                              else config.DATASET.TRAIN_IMAGE_SET,
                                              split='train')
    model_prefix = os.path.join(final_output_path, config.MODEL_PREFIX)
    if args.log_dir is None:
        args.log_dir = os.path.join(final_output_path, 'tensorboard_logs')

    pprint.pprint(args)
    logger.info('training args:{}\n'.format(args))
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # manually set random seed
    if config.RNG_SEED > -1:
        random.seed(config.RNG_SEED)
        np.random.seed(config.RNG_SEED)
        torch.random.manual_seed(config.RNG_SEED)
        torch.cuda.manual_seed_all(config.RNG_SEED)

    # cudnn
    torch.backends.cudnn.benchmark = False
    if args.cudnn_off:
        torch.backends.cudnn.enabled = False

    if args.dist:
        model = eval(config.MODULE)(config)
        local_rank = int(os.environ.get('LOCAL_RANK') or 0)
        config.GPUS = str(local_rank)
        torch.cuda.set_device(local_rank)
        master_address = os.environ['MASTER_ADDR']
        # master_port = int(os.environ['MASTER_PORT'] or 23456)
        # master_port = int(9997)
        master_port = int(9995)
        world_size = int(os.environ['WORLD_SIZE'] or 1)
        rank = int(os.environ['RANK'] or 0)
        if args.slurm:
            distributed.init_process_group(backend='nccl')
        else:
            distributed.init_process_group(
                backend='nccl',
                init_method='tcp://{}:{}'.format(master_address, master_port),
                world_size=world_size,
                rank=rank,
                group_name='mtorch')
        print(f'native distributed, size: {world_size}, rank: {rank}, local rank: {local_rank}')
        torch.cuda.set_device(local_rank)
        config.GPUS = str(local_rank)
        model = model.cuda()
        if not config.TRAIN.FP16:
            model = DDP(model, device_ids=[local_rank], output_device=local_rank)

        if rank == 0:
            summary_parameters(model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model,
                               logger)
            shutil.copy(args.cfg, final_output_path)
            shutil.copy(inspect.getfile(eval(config.MODULE)), final_output_path)

        writer = None
        if args.log_dir is not None:
            tb_log_dir = os.path.join(args.log_dir, 'rank{}'.format(rank))
            if not os.path.exists(tb_log_dir):
                os.makedirs(tb_log_dir)
            writer = SummaryWriter(log_dir=tb_log_dir)

        if isinstance(config.DATASET, list):
            train_loaders_and_samplers = make_dataloaders(config,
                                                          mode='train',
                                                          distributed=True,
                                                          num_replicas=world_size,
                                                          rank=rank,
                                                          expose_sampler=True)
            val_loaders = make_dataloaders(config,
                                           mode='val',
                                           distributed=True,
                                           num_replicas=world_size,
                                           rank=rank)
            train_loader = MultiTaskDataLoader([loader for loader, _ in train_loaders_and_samplers])
            val_loader = MultiTaskDataLoader(val_loaders)
            train_sampler = train_loaders_and_samplers[0][1]
        else:
            train_loader, train_sampler = make_dataloader(config,
                                                          mode='train',
                                                          distributed=True,
                                                          num_replicas=world_size,
                                                          rank=rank,
                                                          expose_sampler=True)
            val_loader = make_dataloader(config,
                                         mode='val',
                                         distributed=True,
                                         num_replicas=world_size,
                                         rank=rank)

        batch_size = world_size * (sum(config.TRAIN.BATCH_IMAGES)
                                   if isinstance(config.TRAIN.BATCH_IMAGES, list)
                                   else config.TRAIN.BATCH_IMAGES)
        if config.TRAIN.GRAD_ACCUMULATE_STEPS > 1:
            batch_size = batch_size * config.TRAIN.GRAD_ACCUMULATE_STEPS
        base_lr = config.TRAIN.LR * batch_size
        optimizer_grouped_parameters = [{'params': [p for n, p in model.named_parameters() if _k in n],
                                         'lr': base_lr * _lr_mult}
                                        for _k, _lr_mult in config.TRAIN.LR_MULT]
        optimizer_grouped_parameters.append({'params': [p for n, p in model.named_parameters()
                                                        if all([_k not in n for _k, _ in config.TRAIN.LR_MULT])]})
        if config.TRAIN.OPTIMIZER == 'SGD':
            optimizer = optim.SGD(optimizer_grouped_parameters,
                                  lr=config.TRAIN.LR * batch_size,
                                  momentum=config.TRAIN.MOMENTUM,
                                  weight_decay=config.TRAIN.WD)
        elif config.TRAIN.OPTIMIZER == 'Adam':
            optimizer = optim.Adam(optimizer_grouped_parameters,
                                   lr=config.TRAIN.LR * batch_size,
                                   weight_decay=config.TRAIN.WD)
        elif config.TRAIN.OPTIMIZER == 'AdamW':
            optimizer = AdamW(optimizer_grouped_parameters,
                              lr=config.TRAIN.LR * batch_size,
                              betas=(0.9, 0.999),
                              eps=1e-6,
                              weight_decay=config.TRAIN.WD,
                              correct_bias=True)
        else:
            raise ValueError('Not support optimizer {}!'.format(config.TRAIN.OPTIMIZER))
        total_gpus = world_size

    else:
        #os.environ['CUDA_VISIBLE_DEVICES'] = config.GPUS
        model = eval(config.MODULE)(config)
        summary_parameters(model, logger)
        shutil.copy(args.cfg, final_output_path)
        shutil.copy(inspect.getfile(eval(config.MODULE)), final_output_path)
        num_gpus = len(config.GPUS.split(','))
        assert num_gpus <= 1 or (not config.TRAIN.FP16), "Not support fp16 with torch.nn.DataParallel. " \
                                                         "Please use amp.parallel.DistributedDataParallel instead."
        total_gpus = num_gpus
        rank = None
        writer = SummaryWriter(log_dir=args.log_dir) if args.log_dir is not None else None

        # model
        if num_gpus > 1:
            model = torch.nn.DataParallel(model, device_ids=[int(d) for d in config.GPUS.split(',')]).cuda()
        else:
            torch.cuda.set_device(int(config.GPUS))
            model.cuda()

        # loader
        if isinstance(config.DATASET, list):
            train_loaders = make_dataloaders(config, mode='train', distributed=False)
            val_loaders = make_dataloaders(config, mode='val', distributed=False)
            train_loader = MultiTaskDataLoader(train_loaders)
            val_loader = MultiTaskDataLoader(val_loaders)
        else:
            train_loader = make_dataloader(config, mode='train', distributed=False)
            val_loader = make_dataloader(config, mode='val', distributed=False)
        train_sampler = None

        batch_size = num_gpus * (sum(config.TRAIN.BATCH_IMAGES) if isinstance(config.TRAIN.BATCH_IMAGES, list)
                                 else config.TRAIN.BATCH_IMAGES)
        if config.TRAIN.GRAD_ACCUMULATE_STEPS > 1:
            batch_size = batch_size * config.TRAIN.GRAD_ACCUMULATE_STEPS
        base_lr = config.TRAIN.LR * batch_size
        optimizer_grouped_parameters = [{'params': [p for n, p in model.named_parameters() if _k in n],
                                         'lr': base_lr * _lr_mult}
                                        for _k, _lr_mult in config.TRAIN.LR_MULT]
        optimizer_grouped_parameters.append({'params': [p for n, p in model.named_parameters()
                                                        if all([_k not in n for _k, _ in config.TRAIN.LR_MULT])]})

        if config.TRAIN.OPTIMIZER == 'SGD':
            optimizer = optim.SGD(optimizer_grouped_parameters,
                                  lr=config.TRAIN.LR * batch_size,
                                  momentum=config.TRAIN.MOMENTUM,
                                  weight_decay=config.TRAIN.WD)
        elif config.TRAIN.OPTIMIZER == 'Adam':
            optimizer = optim.Adam(optimizer_grouped_parameters,
                                   lr=config.TRAIN.LR * batch_size,
                                   weight_decay=config.TRAIN.WD)
        elif config.TRAIN.OPTIMIZER == 'AdamW':
            optimizer = AdamW(optimizer_grouped_parameters,
                              lr=config.TRAIN.LR * batch_size,
                              betas=(0.9, 0.999),
                              eps=1e-6,
                              weight_decay=config.TRAIN.WD,
                              correct_bias=True)
        else:
            raise ValueError('Not support optimizer {}!'.format(config.TRAIN.OPTIMIZER))

    # partial load pretrain state dict
    if config.NETWORK.PARTIAL_PRETRAIN != "":
        pretrain_state_dict = torch.load(config.NETWORK.PARTIAL_PRETRAIN, map_location=lambda storage, loc: storage)['state_dict']
        prefix_change = [prefix_change.split('->') for prefix_change in config.NETWORK.PARTIAL_PRETRAIN_PREFIX_CHANGES]
        if len(prefix_change) > 0:
            pretrain_state_dict_parsed = {}
            for k, v in pretrain_state_dict.items():
                no_match = True
                for pretrain_prefix, new_prefix in prefix_change:
                    if k.startswith(pretrain_prefix):
                        k = new_prefix + k[len(pretrain_prefix):]
                        pretrain_state_dict_parsed[k] = v
                        no_match = False
                        break
                if no_match:
                    pretrain_state_dict_parsed[k] = v
            pretrain_state_dict = pretrain_state_dict_parsed
        # FM edit: introduce alternative initialisations
        if config.NETWORK.INITIALISATION == 'hybrid':
            smart_hybrid_partial_load_model_state_dict(model, pretrain_state_dict)
        elif config.NETWORK.INITIALISATION == 'skip':
            smart_skip_partial_load_model_state_dict(model, pretrain_state_dict)
        else:
            smart_partial_load_model_state_dict(model, pretrain_state_dict)

    # metrics
    metric_kwargs = {'allreduce': args.dist,
                     'num_replicas': world_size if args.dist else 1}
    train_metrics_list = []
    val_metrics_list = []
    if config.NETWORK.WITH_REL_LOSS:
        train_metrics_list.append(retrieval_metrics.RelationshipAccuracy(**metric_kwargs))
        val_metrics_list.append(retrieval_metrics.RelationshipAccuracy(**metric_kwargs))
    if config.NETWORK.WITH_MLM_LOSS:
        if config.MODULE == 'ResNetVLBERTForPretrainingMultitask':
            train_metrics_list.append(retrieval_metrics.MLMAccuracyWVC(**metric_kwargs))
            train_metrics_list.append(retrieval_metrics.MLMAccuracyAUX(**metric_kwargs))
            val_metrics_list.append(retrieval_metrics.MLMAccuracyWVC(**metric_kwargs))
            val_metrics_list.append(retrieval_metrics.MLMAccuracyAUX(**metric_kwargs))
        else:
            train_metrics_list.append(retrieval_metrics.MLMAccuracy(**metric_kwargs))
            val_metrics_list.append(retrieval_metrics.MLMAccuracy(**metric_kwargs))
    if config.NETWORK.WITH_MVRC_LOSS:
        train_metrics_list.append(retrieval_metrics.MVRCAccuracy(**metric_kwargs))
        val_metrics_list.append(retrieval_metrics.MVRCAccuracy(**metric_kwargs))
    for output_name, display_name in config.TRAIN.LOSS_LOGGERS:
        train_metrics_list.append(retrieval_metrics.LossLogger(output_name, display_name=display_name, **metric_kwargs))
        val_metrics_list.append(retrieval_metrics.LossLogger(output_name, display_name=display_name, **metric_kwargs))

    train_metrics = CompositeEvalMetric()
    val_metrics = CompositeEvalMetric()
    for child_metric in train_metrics_list:
        train_metrics.add(child_metric)
    for child_metric in val_metrics_list:
        val_metrics.add(child_metric)

    # epoch end callbacks
    epoch_end_callbacks = []
    if (rank is None) or (rank == 0):
        epoch_end_callbacks = [Checkpoint(model_prefix, config.CHECKPOINT_FREQUENT)]
    host_metric_name = 'MLMAcc' if config.MODULE != 'ResNetVLBERTForPretrainingMultitask' else 'MLMAccWVC'
    validation_monitor = ValidationMonitor(do_validation, val_loader, val_metrics,
                                           host_metric_name=host_metric_name)

    # optimizer initial lr before
    for group in optimizer.param_groups:
        group.setdefault('initial_lr', group['lr'])

    # resume/auto-resume
    if rank is None or rank == 0:
        smart_resume(model, optimizer, validation_monitor, config, model_prefix, logger)
    if args.dist:
        begin_epoch = torch.tensor(config.TRAIN.BEGIN_EPOCH).cuda()
        distributed.broadcast(begin_epoch, src=0)
        config.TRAIN.BEGIN_EPOCH = begin_epoch.item()

    # batch end callbacks
    batch_size = len(config.GPUS.split(',')) * (sum(config.TRAIN.BATCH_IMAGES)
                                                if isinstance(config.TRAIN.BATCH_IMAGES, list)
                                                else config.TRAIN.BATCH_IMAGES)
    batch_end_callbacks = [Speedometer(batch_size, config.LOG_FREQUENT,
                                       batches_per_epoch=len(train_loader),
                                       epochs=config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH)]

    # setup lr step and lr scheduler
    if config.TRAIN.LR_SCHEDULE == 'plateau':
        print("Warning: not support resuming on plateau lr schedule!")
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                                  mode='max',
                                                                  factor=config.TRAIN.LR_FACTOR,
                                                                  patience=1,
                                                                  verbose=True,
                                                                  threshold=1e-4,
                                                                  threshold_mode='rel',
                                                                  cooldown=2,
                                                                  min_lr=0,
                                                                  eps=1e-8)
    elif config.TRAIN.LR_SCHEDULE == 'triangle':
        lr_scheduler = WarmupLinearSchedule(optimizer,
                                            config.TRAIN.WARMUP_STEPS if config.TRAIN.WARMUP else 0,
                                            t_total=int(config.TRAIN.END_EPOCH * len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS),
                                            last_epoch=int(config.TRAIN.BEGIN_EPOCH * len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS)  - 1)
    elif config.TRAIN.LR_SCHEDULE == 'step':
        lr_iters = [int(epoch * len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS) for epoch in config.TRAIN.LR_STEP]
        lr_scheduler = WarmupMultiStepLR(optimizer, milestones=lr_iters, gamma=config.TRAIN.LR_FACTOR,
                                         warmup_factor=config.TRAIN.WARMUP_FACTOR,
                                         warmup_iters=config.TRAIN.WARMUP_STEPS if config.TRAIN.WARMUP else 0,
                                         warmup_method=config.TRAIN.WARMUP_METHOD,
                                         last_epoch=int(config.TRAIN.BEGIN_EPOCH * len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS)  - 1)
    else:
        raise ValueError("Not support lr schedule: {}.".format(config.TRAIN.LR_SCHEDULE))

    # broadcast parameter and optimizer state from rank 0 before training start
    if args.dist:
        for v in model.state_dict().values():
            distributed.broadcast(v, src=0)
        # for v in optimizer.state_dict().values():
        #     distributed.broadcast(v, src=0)
        best_epoch = torch.tensor(validation_monitor.best_epoch).cuda()
        best_val = torch.tensor(validation_monitor.best_val).cuda()
        distributed.broadcast(best_epoch, src=0)
        distributed.broadcast(best_val, src=0)
        validation_monitor.best_epoch = best_epoch.item()
        validation_monitor.best_val = best_val.item()

    # apex: amp fp16 mixed-precision training
    if config.TRAIN.FP16:
        # model.apply(bn_fp16_half_eval)
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level='O2',
                                          keep_batchnorm_fp32=False,
                                          loss_scale=config.TRAIN.FP16_LOSS_SCALE,
                                          max_loss_scale=128.0,
                                          min_loss_scale=128.0)
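        # With max_loss_scale == min_loss_scale == 128 the amp loss scale is effectively
        # pinned to a constant 128 instead of being adjusted dynamically.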
        if args.dist:
            model = Apex_DDP(model, delay_allreduce=True)

    train(model, optimizer, lr_scheduler, train_loader, train_sampler, train_metrics,
          config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH, logger,
          rank=rank, batch_end_callbacks=batch_end_callbacks, epoch_end_callbacks=epoch_end_callbacks,
          writer=writer, validation_monitor=validation_monitor, fp16=config.TRAIN.FP16,
          clip_grad_norm=config.TRAIN.CLIP_GRAD_NORM,
          gradient_accumulate_steps=config.TRAIN.GRAD_ACCUMULATE_STEPS)

    return rank, model
Example n. 8
def val_net(args, config, ckpt_path=None, save_path=None, save_name=None):
    print('val net...')
    pprint.pprint(args)
    pprint.pprint(config)
    device_ids = [int(d) for d in config.GPUS.split(',')]
    # os.environ['CUDA_VISIBLE_DEVICES'] = config.GPUS

    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if ckpt_path is None:
        _, train_output_path = create_logger(config.OUTPUT_PATH, args.cfg, config.DATASET.TRAIN_IMAGE_SET,
                                             split='train')
        model_prefix = os.path.join(train_output_path, config.MODEL_PREFIX)
        ckpt_path = '{}-best.model'.format(model_prefix)
        print('Use best checkpoint {}...'.format(ckpt_path))
    # if save_path is None:
    #     logger, test_output_path = create_logger(config.OUTPUT_PATH, args.cfg, config.DATASET.TEST_IMAGE_SET,
    #                                              split='test')
    #     save_path = test_output_path
    # if not os.path.exists(save_path):
    #     os.makedirs(save_path)
    # shutil.copy2(ckpt_path,
    #              os.path.join(save_path, '{}_test_ckpt_{}.model'.format(config.MODEL_PREFIX, config.DATASET.TASK)))

    # get network
    model = eval(config.MODULE)(config)

    if hasattr(model, 'setup_adapter'):
        model.setup_adapter()

    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()
    else:
        torch.cuda.set_device(device_ids[0])
        model = model.cuda()
    checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    smart_load_model_state_dict(model, checkpoint['state_dict'])

    # loader
    test_loader = make_dataloader(config, mode='val', distributed=False)
    test_dataset = test_loader.dataset
    test_database = test_dataset.database

    # test
    predicts = []
    model.eval()
    cur_id = 0
    for nbatch, batch in zip(trange(len(test_loader)), test_loader):
    # for nbatch, batch in tqdm(enumerate(test_loader)):
        bs = test_loader.batch_sampler.batch_size if test_loader.batch_sampler is not None else test_loader.batch_size
        batch = to_cuda(batch)
        outputs = model(*batch[:-1])
        if outputs['label_logits'].shape[-1] == 1:
            prob = torch.sigmoid(outputs['label_logits'][:, 0]).detach().cpu().tolist()
        else:
            prob = torch.softmax(outputs['label_logits'], dim=-1)[:, 1].detach().cpu().tolist()
        
        sample_ids = batch[-1].cpu().tolist()
        targets = batch[config.DATASET.LABEL_INDEX_IN_BATCH]
        for pb, id, tg in zip(prob, sample_ids, targets):
            predicts.append({
                'id': int(id),
                'proba': float(pb),
                'label': int(pb > 0.5),
                'target': float(tg)
            })

    pred_probs = [p['proba'] for p in predicts]
    pred_labels = [p['label'] for p in predicts]
    targets = [p['target'] for p in predicts]
    
    roc_auc = roc_auc_score(targets, pred_probs)
    print(f"roc_auc: {roc_auc}")

    max_accuracy = 0.0
    best_threshold = 1e-2
    for th in range(1, 100):
        targets_idx = [int(p['target'] > 1e-2 * th) for p in predicts]
        accuracy = accuracy_score(targets_idx, pred_labels)
        if accuracy > max_accuracy:
            max_accuracy = accuracy
            best_threshold = th * 1e-2
    print(f"max accuracy: {max_accuracy}, best_threshold: {best_threshold}")
Example n. 9
def test_net(args, config, ckpt_path=None, save_path=None, save_name=None):
    print('test net...')
    pprint.pprint(args)
    pprint.pprint(config)
    device_ids = [int(d) for d in config.GPUS.split(',')]
    # os.environ['CUDA_VISIBLE_DEVICES'] = config.GPUS

    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if ckpt_path is None:
        _, train_output_path = create_logger(config.OUTPUT_PATH,
                                             args.cfg,
                                             config.DATASET.TRAIN_IMAGE_SET,
                                             split='train')
        model_prefix = os.path.join(train_output_path, config.MODEL_PREFIX)
        ckpt_path = '{}-best.model'.format(model_prefix)
        print('Use best checkpoint {}...'.format(ckpt_path))
    if save_path is None:
        logger, test_output_path = create_logger(config.OUTPUT_PATH,
                                                 args.cfg,
                                                 config.DATASET.TEST_IMAGE_SET,
                                                 split='test')
        save_path = test_output_path
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    shutil.copy2(
        ckpt_path,
        os.path.join(
            save_path, '{}_test_ckpt_{}.model'.format(config.MODEL_PREFIX,
                                                      config.DATASET.TASK)))

    # get network
    model = eval(config.MODULE)(config)
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()
    else:
        torch.cuda.set_device(device_ids[0])
        model = model.cuda()
    checkpoint = torch.load(ckpt_path,
                            map_location=lambda storage, loc: storage)
    smart_load_model_state_dict(model, checkpoint['state_dict'])

    # loader
    test_loader = make_dataloader(config, mode='test', distributed=False)
    test_dataset = test_loader.dataset
    test_database = test_dataset.database

    # test
    q_ids = []
    answer_ids = []
    model.eval()
    cur_id = 0
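    # cur_id tracks the absolute position in test_database; this assumes the test loader
    # iterates the database in order (no shuffling), so annot_ids line up with each batch.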
    for nbatch, batch in zip(trange(len(test_loader)), test_loader):
        # for nbatch, batch in tqdm(enumerate(test_loader)):
        bs = test_loader.batch_sampler.batch_size if test_loader.batch_sampler is not None else test_loader.batch_size
        q_ids.extend([
            str(test_database[id]['annot_id'])
            for id in range(cur_id, min(cur_id + bs, len(test_database)))
        ])
        batch = to_cuda(batch)
        output = model(*batch)
        answer_ids.extend(output['label_logits'].cpu().numpy().tolist())
        cur_id += bs

    result = [q_ids, answer_ids]

    cfg_name = os.path.splitext(os.path.basename(args.cfg))[0]
    result_json_path = os.path.join(
        save_path,
        '{}_vqa2_{}.json'.format(cfg_name if save_name is None else save_name,
                                 config.DATASET.TEST_IMAGE_SET))
    with open(result_json_path, 'w') as f:
        json.dump(result, f)
    print('result json saved to {}.'.format(result_json_path))
    return result_json_path
Example n. 10
def test_translation_net(args, config, ckpt_path=None, save_path=None, save_name=None):
    print('test net...')
    pprint.pprint(args)
    pprint.pprint(config)
    device_ids = [int(d) for d in config.GPUS.split(',')]
    # os.environ['CUDA_VISIBLE_DEVICES'] = config.GPUS

    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if ckpt_path is None:
        _, train_output_path = create_logger(config.OUTPUT_PATH, args.cfg, config.DATASET.TRAIN_IMAGE_SET,
                                             split='train')
        model_prefix = os.path.join(train_output_path, config.MODEL_PREFIX)
        ckpt_path = '{}-best.model'.format(model_prefix)
        print('Use best checkpoint {}...'.format(ckpt_path))
    if save_path is None:
        logger, test_output_path = create_logger(config.OUTPUT_PATH, args.cfg, config.DATASET.TEST_IMAGE_SET,
                                                 split='test')
        save_path = test_output_path
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    shutil.copy2(ckpt_path,
                 os.path.join(save_path, '{}_test_ckpt_{}.model'.format(config.MODEL_PREFIX, config.DATASET.TASK)))

    # ************
    # Step 1: Select model architecture and preload trained model
    model = eval(config.MODULE)(config)
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()
    else:
        torch.cuda.set_device(device_ids[0])
        model = model.cuda()
    checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    smart_load_model_state_dict(model, checkpoint['state_dict'])

    # ************
    # Step 2: Create dataloader to include all caption-image pairs
    test_loader = make_dataloader(config, mode='test', distributed=False)
    test_dataset = test_loader.dataset
    test_database = test_dataset.database

    # ************
    # Step 3: Run all pairs through model for inference
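    # For every pair the model presumably scores whether the English and German captions
    # match; the sigmoid turns the relationship logit into a probability that is collected
    # per (caption_en_index, caption_de_index) pair.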
    caption_ids = []
    image_ids = []
    logits = []
    model.eval()
    cur_id = 0
    for nbatch, batch in zip(trange(len(test_loader)), test_loader):
        bs = test_loader.batch_sampler.batch_size if test_loader.batch_sampler is not None else test_loader.batch_size
        caption_ids.extend([test_database[id]['caption_en_index'] for id in range(cur_id, min(cur_id + bs, len(test_database)))])
        image_ids.extend([test_database[id]['caption_de_index'] for id in range(cur_id, min(cur_id + bs, len(test_database)))])
        batch = to_cuda(batch)
        output = model(*batch)
        logits.extend(torch.sigmoid(output[0]['relationship_logits']).detach().cpu().tolist())
        cur_id += bs
        # TODO: remove, this is just for checking
        # if nbatch>900:
        #     break
   
    # ************
    # Step 4: Store all logit results in a file for later evaluation
    result = [{'caption_en_index': c_id, 'caption_de_index': i_id, 'logit': l_id} for c_id, i_id, l_id in zip(caption_ids, image_ids, logits)]
    cfg_name = os.path.splitext(os.path.basename(args.cfg))[0]
    result_json_path = os.path.join(save_path, '{}_retrieval_translation_{}.json'.format(cfg_name if save_name is None else save_name,
                                                                        config.DATASET.TEST_IMAGE_SET))
    with open(result_json_path, 'w') as f:
        json.dump(result, f)
    print('result json saved to {}.'.format(result_json_path))
    return result_json_path
Example n. 11
def test_net(args, config, ckpt_path=None, save_path=None, save_name=None):
    print('test net...')
    pprint.pprint(args)
    pprint.pprint(config)
    device_ids = [int(d) for d in config.GPUS.split(',')]
    # os.environ['CUDA_VISIBLE_DEVICES'] = config.GPUS

    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if ckpt_path is None:
        _, train_output_path = create_logger(config.OUTPUT_PATH, args.cfg, config.DATASET.TRAIN_IMAGE_SET,
                                             split='train')
        model_prefix = os.path.join(train_output_path, config.MODEL_PREFIX)
        ckpt_path = '{}-best.model'.format(model_prefix)
        print('Use best checkpoint {}...'.format(ckpt_path))
    if save_path is None:
        logger, test_output_path = create_logger(config.OUTPUT_PATH, args.cfg, config.DATASET.TEST_IMAGE_SET,
                                                 split='test')
        save_path = test_output_path
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    shutil.copy2(ckpt_path,
                 os.path.join(save_path, '{}_test_ckpt_{}.model'.format(config.MODEL_PREFIX, config.DATASET.TASK)))

    # get network
    model = eval(config.MODULE)(config)
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()
    else:
        torch.cuda.set_device(device_ids[0])
        model = model.cuda()
    checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    smart_load_model_state_dict(model, checkpoint['state_dict'])

    # loader
    test_loader = make_dataloader(config, mode='test', distributed=False)

    split = args.split + 'id' if args.split == 'val' else args.split # 'val' -> 'valid'

    # test
    if config.TEST.EXCL_LEFT_RIGHT:
        precompute_test_cache = f'{args.log_dir}/pred_{split}_{ckpt_path[-10:-6]}_excl-left-right.pickle'
    else:
        precompute_test_cache = f'{args.log_dir}/pred_{split}_{ckpt_path[-10:-6]}.pickle'
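    # The cache file name embeds a 4-character slice of the checkpoint filename
    # (ckpt_path[-10:-6]), so re-running with the same checkpoint skips inference below.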
    if not os.path.isdir(args.log_dir):
        os.makedirs(args.log_dir)
    pred_file = precompute_test_cache
    if not os.path.exists(precompute_test_cache):
        _ids = []
        losses = []
        predictions = []
        model.eval()
        
        if args.visualize_mask: # For mask visualization purpose
            save_dir = 'heatmap/spasen'
            if not os.path.isdir(save_dir):
                # os.mkdir(save_dir)
                os.makedirs(save_dir)

        for nbatch, batch in zip(trange(len(test_loader)), test_loader):
            _ids.extend(batch[0]) # the first input element is _id

            batch = to_cuda(batch)
            output = model(*batch)
            
            predictions.append(output['prediction'])
            losses.append(output['ans_loss'].item())

            if args.visualize_mask: # For mask visualization purpose
                mask = output['spo_fused_masks'].cpu() # torch.Size([8, 3, 14, 14])
                subj_name = output['subj_name'] # list of 8 strs
                obj_name = output['obj_name'] # list of 8 strs
                pred_name = output['pred_name'] # list of 8 strs
                im_path = output['im_path'] # list of 8 img urls

                for i in range(mask.shape[0]):
                    img, dataset = read_img(im_path[i], config.IMAGEPATH)
                    img_name = dataset + '-' + im_path[i].split('/')[-1]
                    show_cam_on_image(img, mask[i], img_name, subj_name[i], obj_name[i], pred_name[i], save_dir)

        predictions = [v.item() for v in torch.cat(predictions)]
        loss = sum(losses) / len(losses)
        pickle.dump((_ids, predictions, loss), open(pred_file, 'wb'))

    accs, loss = accuracies(pred_file, 'data/spasen/annotations.json', split)

    return accs, loss
Example n. 12
def train_net(args, config):
    # setup logger
    logger, final_output_path = create_logger(config.OUTPUT_PATH,
                                              args.cfg,
                                              config.DATASET.TRAIN_IMAGE_SET,
                                              split='train')
    model_prefix = os.path.join(final_output_path, config.MODEL_PREFIX)
    if args.log_dir is None:
        args.log_dir = os.path.join(final_output_path, 'tensorboard_logs')

    # pprint.pprint(args)
    # logger.info('training args:{}\n'.format(args))
    # pprint.pprint(config)
    # logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # manually set random seed
    if config.RNG_SEED > -1:
        random.seed(a=config.RNG_SEED)
        np.random.seed(config.RNG_SEED)
        torch.random.manual_seed(config.RNG_SEED)
        torch.cuda.manual_seed_all(config.RNG_SEED)
        torch.backends.cudnn.deterministic = True
        imgaug.random.seed(config.RNG_SEED)

    # cudnn
    torch.backends.cudnn.benchmark = False
    if args.cudnn_off:
        torch.backends.cudnn.enabled = False

    if args.dist:
        model = eval(config.MODULE)(config)
        local_rank = int(os.environ.get('LOCAL_RANK') or 0)
        config.GPUS = str(local_rank)
        torch.cuda.set_device(local_rank)
        master_address = os.environ['MASTER_ADDR']
        master_port = int(os.environ['MASTER_PORT'] or 23456)
        world_size = int(os.environ['WORLD_SIZE'] or 1)
        rank = int(os.environ['RANK'] or 0)

        if rank == 0:
            pprint.pprint(args)
            logger.info('training args:{}\n'.format(args))
            pprint.pprint(config)
            logger.info('training config:{}\n'.format(pprint.pformat(config)))

        if args.slurm:
            distributed.init_process_group(backend='nccl')
        else:
            try:
                distributed.init_process_group(
                    backend='nccl',
                    init_method='tcp://{}:{}'.format(master_address,
                                                     master_port),
                    world_size=world_size,
                    rank=rank,
                    group_name='mtorch')
            except RuntimeError:
                pass
        print(
            f'native distributed, size: {world_size}, rank: {rank}, local rank: {local_rank}'
        )
        torch.cuda.set_device(local_rank)
        config.GPUS = str(local_rank)
        model = model.cuda()
        if not config.TRAIN.FP16:
            model = DDP(model,
                        device_ids=[local_rank],
                        output_device=local_rank,
                        find_unused_parameters=True)

        if rank == 0:
            summary_parameters(
                model.module if isinstance(
                    model, torch.nn.parallel.DistributedDataParallel) else
                model, logger)
            shutil.copy(args.cfg, final_output_path)
            shutil.copy(inspect.getfile(eval(config.MODULE)),
                        final_output_path)

        writer = None
        if args.log_dir is not None:
            tb_log_dir = os.path.join(args.log_dir, 'rank{}'.format(rank))
            if not os.path.exists(tb_log_dir):
                os.makedirs(tb_log_dir)
            writer = SummaryWriter(log_dir=tb_log_dir)

        batch_size = world_size * (sum(config.TRAIN.BATCH_IMAGES) if
                                   isinstance(config.TRAIN.BATCH_IMAGES, list)
                                   else config.TRAIN.BATCH_IMAGES)
        if config.TRAIN.GRAD_ACCUMULATE_STEPS > 1:
            batch_size = batch_size * config.TRAIN.GRAD_ACCUMULATE_STEPS
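        # Linear LR scaling: the configured LR is presumably a per-sample rate, so it is
        # multiplied by the global batch size; per-group multipliers from config.TRAIN.LR_MULT
        # are applied below, and parameters not matching any multiplier fall into a default group.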
        base_lr = config.TRAIN.LR * batch_size
        optimizer_grouped_parameters = [{
            'params': [p for n, p in model.named_parameters() if _k in n],
            'lr':
            base_lr * _lr_mult
        } for _k, _lr_mult in config.TRAIN.LR_MULT]
        optimizer_grouped_parameters.append({
            'params': [
                p for n, p in model.named_parameters()
                if all([_k not in n for _k, _ in config.TRAIN.LR_MULT])
            ]
        })
        if config.TRAIN.OPTIMIZER == 'SGD':
            optimizer = optim.SGD(optimizer_grouped_parameters,
                                  lr=config.TRAIN.LR * batch_size,
                                  momentum=config.TRAIN.MOMENTUM,
                                  weight_decay=config.TRAIN.WD)
        elif config.TRAIN.OPTIMIZER == 'Adam':
            optimizer = optim.Adam(optimizer_grouped_parameters,
                                   lr=config.TRAIN.LR * batch_size,
                                   weight_decay=config.TRAIN.WD)
        elif config.TRAIN.OPTIMIZER == 'AdamW':
            optimizer = AdamW(optimizer_grouped_parameters,
                              lr=config.TRAIN.LR * batch_size,
                              betas=(0.9, 0.999),
                              eps=1e-6,
                              weight_decay=config.TRAIN.WD,
                              correct_bias=True)
        else:
            raise ValueError('Unsupported optimizer {}!'.format(
                config.TRAIN.OPTIMIZER))
        total_gpus = world_size

        train_loader, train_sampler = make_dataloader(config,
                                                      mode='train',
                                                      distributed=True,
                                                      num_replicas=world_size,
                                                      rank=rank,
                                                      expose_sampler=True)
        val_loader = make_dataloader(config,
                                     mode='val',
                                     distributed=True,
                                     num_replicas=world_size,
                                     rank=rank)

    else:
        pprint.pprint(args)
        logger.info('training args:{}\n'.format(args))
        pprint.pprint(config)
        logger.info('training config:{}\n'.format(pprint.pformat(config)))

        #os.environ['CUDA_VISIBLE_DEVICES'] = config.GPUS
        model = eval(config.MODULE)(config)
        summary_parameters(model, logger)
        shutil.copy(args.cfg, final_output_path)
        shutil.copy(inspect.getfile(eval(config.MODULE)), final_output_path)
        num_gpus = len(config.GPUS.split(','))
        # assert num_gpus <= 1 or (not config.TRAIN.FP16), "Not support fp16 with torch.nn.DataParallel. " \
        #                                                  "Please use amp.parallel.DistributedDataParallel instead."
        if num_gpus > 1 and config.TRAIN.FP16:
            logger.warning("Not support fp16 with torch.nn.DataParallel.")
            config.TRAIN.FP16 = False

        total_gpus = num_gpus
        rank = None
        writer = SummaryWriter(
            log_dir=args.log_dir) if args.log_dir is not None else None

        if hasattr(model, 'setup_adapter'):
            logger.info('Setting up adapter modules!')
            model.setup_adapter()

        # model
        if num_gpus > 1:
            model = torch.nn.DataParallel(
                model,
                device_ids=[int(d) for d in config.GPUS.split(',')]).cuda()
        else:
            torch.cuda.set_device(int(config.GPUS))
            model.cuda()

        # loader
        # train_set = 'train+val' if config.DATASET.TRAIN_WITH_VAL else 'train'
        train_loader = make_dataloader(config, mode='train', distributed=False)
        val_loader = make_dataloader(config, mode='val', distributed=False)
        train_sampler = None

        batch_size = num_gpus * (sum(config.TRAIN.BATCH_IMAGES) if isinstance(
            config.TRAIN.BATCH_IMAGES, list) else config.TRAIN.BATCH_IMAGES)
        if config.TRAIN.GRAD_ACCUMULATE_STEPS > 1:
            batch_size = batch_size * config.TRAIN.GRAD_ACCUMULATE_STEPS
        base_lr = config.TRAIN.LR * batch_size
        optimizer_grouped_parameters = [{
            'params': [p for n, p in model.named_parameters() if _k in n],
            'lr':
            base_lr * _lr_mult
        } for _k, _lr_mult in config.TRAIN.LR_MULT]
        optimizer_grouped_parameters.append({
            'params': [
                p for n, p in model.named_parameters()
                if all([_k not in n for _k, _ in config.TRAIN.LR_MULT])
            ]
        })

        if config.TRAIN.OPTIMIZER == 'SGD':
            optimizer = optim.SGD(optimizer_grouped_parameters,
                                  lr=config.TRAIN.LR * batch_size,
                                  momentum=config.TRAIN.MOMENTUM,
                                  weight_decay=config.TRAIN.WD)
        elif config.TRAIN.OPTIMIZER == 'Adam':
            optimizer = optim.Adam(optimizer_grouped_parameters,
                                   lr=config.TRAIN.LR * batch_size,
                                   weight_decay=config.TRAIN.WD)
        elif config.TRAIN.OPTIMIZER == 'AdamW':
            optimizer = AdamW(optimizer_grouped_parameters,
                              lr=config.TRAIN.LR * batch_size,
                              betas=(0.9, 0.999),
                              eps=1e-6,
                              weight_decay=config.TRAIN.WD,
                              correct_bias=True)
        else:
            raise ValueError('Unsupported optimizer {}!'.format(
                config.TRAIN.OPTIMIZER))

    # partial load pretrain state dict
    if config.NETWORK.PARTIAL_PRETRAIN != "":
        pretrain_state_dict = torch.load(
            config.NETWORK.PARTIAL_PRETRAIN,
            map_location=lambda storage, loc: storage)['state_dict']
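        # Each entry of PARTIAL_PRETRAIN_PREFIX_CHANGES is an 'old_prefix->new_prefix' string,
        # e.g. 'module.vlbert->vlbert' (hypothetical example); matching keys are renamed before
        # the partial state dict is loaded.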
        prefix_change = [
            prefix_change.split('->')
            for prefix_change in config.NETWORK.PARTIAL_PRETRAIN_PREFIX_CHANGES
        ]
        if len(prefix_change) > 0:
            pretrain_state_dict_parsed = {}
            for k, v in pretrain_state_dict.items():
                no_match = True
                for pretrain_prefix, new_prefix in prefix_change:
                    if k.startswith(pretrain_prefix):
                        k = new_prefix + k[len(pretrain_prefix):]
                        pretrain_state_dict_parsed[k] = v
                        no_match = False
                        break
                if no_match:
                    pretrain_state_dict_parsed[k] = v
            pretrain_state_dict = pretrain_state_dict_parsed
        smart_partial_load_model_state_dict(model, pretrain_state_dict)

    # pretrained classifier
    # if config.NETWORK.CLASSIFIER_PRETRAINED:
    #     print('Initializing classifier weight from pretrained word embeddings...')
    #     answers_word_embed = []
    #     for k, v in model.state_dict().items():
    #         if 'word_embeddings.weight' in k:
    #             word_embeddings = v.detach().clone()
    #             break
    #     for answer in train_loader.dataset.answer_vocab:
    #         a_tokens = train_loader.dataset.tokenizer.tokenize(answer)
    #         a_ids = train_loader.dataset.tokenizer.convert_tokens_to_ids(a_tokens)
    #         a_word_embed = (torch.stack([word_embeddings[a_id] for a_id in a_ids], dim=0)).mean(dim=0)
    #         answers_word_embed.append(a_word_embed)
    #     answers_word_embed_tensor = torch.stack(answers_word_embed, dim=0)
    #     for name, module in model.named_modules():
    #         if name.endswith('final_mlp'):
    #             module[-1].weight.data = answers_word_embed_tensor.to(device=module[-1].weight.data.device)

    # metrics
    train_metrics_list = [
        cls_metrics.Accuracy(allreduce=args.dist,
                             num_replicas=world_size if args.dist else 1)
    ]
    val_metrics_list = [
        cls_metrics.Accuracy(allreduce=args.dist,
                             num_replicas=world_size if args.dist else 1),
        cls_metrics.RocAUC(allreduce=args.dist,
                           num_replicas=world_size if args.dist else 1)
    ]
    for output_name, display_name in config.TRAIN.LOSS_LOGGERS:
        train_metrics_list.append(
            cls_metrics.LossLogger(
                output_name,
                display_name=display_name,
                allreduce=args.dist,
                num_replicas=world_size if args.dist else 1))

    train_metrics = CompositeEvalMetric()
    val_metrics = CompositeEvalMetric()
    for child_metric in train_metrics_list:
        train_metrics.add(child_metric)
    for child_metric in val_metrics_list:
        val_metrics.add(child_metric)

    # epoch end callbacks
    epoch_end_callbacks = []
    if (rank is None) or (rank == 0):
        epoch_end_callbacks = [
            Checkpoint(model_prefix, config.CHECKPOINT_FREQUENT)
        ]
    validation_monitor = ValidationMonitor(
        do_validation,
        val_loader,
        val_metrics,
        host_metric_name='RocAUC',
        label_index_in_batch=config.DATASET.LABEL_INDEX_IN_BATCH,
        model_dir=os.path.dirname(model_prefix))

    # optimizer initial lr before
    for group in optimizer.param_groups:
        group.setdefault('initial_lr', group['lr'])

    # resume/auto-resume
    if rank is None or rank == 0:
        smart_resume(model, optimizer, validation_monitor, config,
                     model_prefix, logger)
    if args.dist:
        begin_epoch = torch.tensor(config.TRAIN.BEGIN_EPOCH).cuda()
        distributed.broadcast(begin_epoch, src=0)
        config.TRAIN.BEGIN_EPOCH = begin_epoch.item()

    # batch end callbacks
    batch_size = len(config.GPUS.split(',')) * config.TRAIN.BATCH_IMAGES
    batch_end_callbacks = [
        Speedometer(batch_size,
                    config.LOG_FREQUENT,
                    batches_per_epoch=len(train_loader),
                    epochs=config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH)
    ]

    # setup lr step and lr scheduler
    if config.TRAIN.LR_SCHEDULE == 'plateau':
        print("Warning: not support resuming on plateau lr schedule!")
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=config.TRAIN.LR_FACTOR,
            patience=1,
            verbose=True,
            threshold=1e-4,
            threshold_mode='rel',
            cooldown=2,
            min_lr=0,
            eps=1e-8)
    elif config.TRAIN.LR_SCHEDULE == 'triangle':
        lr_scheduler = WarmupLinearSchedule(
            optimizer,
            config.TRAIN.WARMUP_STEPS if config.TRAIN.WARMUP else 0,
            t_total=int(config.TRAIN.END_EPOCH * len(train_loader) /
                        config.TRAIN.GRAD_ACCUMULATE_STEPS),
            last_epoch=int(config.TRAIN.BEGIN_EPOCH * len(train_loader) /
                           config.TRAIN.GRAD_ACCUMULATE_STEPS) - 1)
    elif config.TRAIN.LR_SCHEDULE == 'step':
        lr_iters = [
            int(epoch * len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS)
            for epoch in config.TRAIN.LR_STEP
        ]
        lr_scheduler = WarmupMultiStepLR(
            optimizer,
            milestones=lr_iters,
            gamma=config.TRAIN.LR_FACTOR,
            warmup_factor=config.TRAIN.WARMUP_FACTOR,
            warmup_iters=config.TRAIN.WARMUP_STEPS
            if config.TRAIN.WARMUP else 0,
            warmup_method=config.TRAIN.WARMUP_METHOD,
            last_epoch=int(config.TRAIN.BEGIN_EPOCH * len(train_loader) /
                           config.TRAIN.GRAD_ACCUMULATE_STEPS) - 1)
    else:
        raise ValueError("Not support lr schedule: {}.".format(
            config.TRAIN.LR_SCHEDULE))

    if config.TRAIN.SWA:
        assert config.TRAIN.SWA_START_EPOCH < config.TRAIN.END_EPOCH
        if not config.TRAIN.DEBUG:
            true_epoch_step = len(
                train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS
        else:
            true_epoch_step = 50
        step_per_cycle = config.TRAIN.SWA_EPOCH_PER_CYCLE * true_epoch_step

        # swa_scheduler = torch.optim.lr_scheduler.CyclicLR(
        #     optimizer,
        #     base_lr=config.TRAIN.SWA_MIN_LR * batch_size,
        #     max_lr=config.TRAIN.SWA_MAX_LR * batch_size,
        #     cycle_momentum=False,
        #     step_size_up=10,
        #     step_size_down=step_per_cycle - 10)

        anneal_steps = max(
            1, (config.TRAIN.END_EPOCH - config.TRAIN.SWA_START_EPOCH) //
            4) * step_per_cycle
        anneal_steps = int(anneal_steps)
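        # SWALR anneals the LR linearly to swa_lr over roughly a quarter of the remaining
        # SWA epochs, counted in optimizer steps (anneal_epochs is interpreted as steps here
        # because the scheduler is presumably stepped once per batch).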
        swa_scheduler = SWALR(optimizer,
                              anneal_epochs=anneal_steps,
                              anneal_strategy='linear',
                              swa_lr=config.TRAIN.SWA_MAX_LR * batch_size)
    else:
        swa_scheduler = None

    if config.TRAIN.ROC_STAR:
        assert config.TRAIN.ROC_START_EPOCH < config.TRAIN.END_EPOCH
        roc_star = RocStarLoss(
            delta=2.0,
            sample_size=config.TRAIN.ROC_SAMPLE_SIZE,
            sample_size_gamma=config.TRAIN.ROC_SAMPLE_SIZE * 2,
            update_gamma_each=config.TRAIN.ROC_SAMPLE_SIZE,
        )
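        # RocStarLoss is presumably a differentiable ROC-AUC surrogate; it keeps a sliding
        # sample of past predictions (sample_size / sample_size_gamma) and refreshes its gamma
        # margin every ROC_SAMPLE_SIZE steps, interleaved with the main loss during training.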
    else:
        roc_star = None

    # broadcast parameter and optimizer state from rank 0 before training start
    if args.dist:
        for v in model.state_dict().values():
            distributed.broadcast(v, src=0)
        # for v in optimizer.state_dict().values():
        #     distributed.broadcast(v, src=0)
        best_epoch = torch.tensor(validation_monitor.best_epoch).cuda()
        best_val = torch.tensor(validation_monitor.best_val).cuda()
        distributed.broadcast(best_epoch, src=0)
        distributed.broadcast(best_val, src=0)
        validation_monitor.best_epoch = best_epoch.item()
        validation_monitor.best_val = best_val.item()

    # apex: amp fp16 mixed-precision training
    if config.TRAIN.FP16:
        # model.apply(bn_fp16_half_eval)
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level='O2',
            keep_batchnorm_fp32=False,
            loss_scale=config.TRAIN.FP16_LOSS_SCALE,
            min_loss_scale=32.0)
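        # min_loss_scale=32 floors apex's dynamic loss scaling so it never drops below 32.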
        if args.dist:
            model = Apex_DDP(model, delay_allreduce=True)

    # NOTE: final_model == model if not using SWA, else final_model == AveragedModel(model)
    final_model = train(
        model,
        optimizer,
        lr_scheduler,
        train_loader,
        train_sampler,
        train_metrics,
        config.TRAIN.BEGIN_EPOCH,
        config.TRAIN.END_EPOCH,
        logger,
        fp16=config.TRAIN.FP16,
        rank=rank,
        writer=writer,
        batch_end_callbacks=batch_end_callbacks,
        epoch_end_callbacks=epoch_end_callbacks,
        validation_monitor=validation_monitor,
        clip_grad_norm=config.TRAIN.CLIP_GRAD_NORM,
        gradient_accumulate_steps=config.TRAIN.GRAD_ACCUMULATE_STEPS,
        ckpt_path=config.TRAIN.CKPT_PATH,
        swa_scheduler=swa_scheduler,
        swa_start_epoch=config.TRAIN.SWA_START_EPOCH,
        swa_cycle_epoch=config.TRAIN.SWA_EPOCH_PER_CYCLE,
        swa_use_scheduler=config.TRAIN.SWA_SCHEDULE,
        roc_star=roc_star,
        roc_star_start_epoch=config.TRAIN.ROC_START_EPOCH,
        roc_interleave=config.TRAIN.ROC_INTERLEAVE,
        debug=config.TRAIN.DEBUG,
    )

    return rank, final_model
Example n. 13
def train_net(args, config):
    # setup logger
    logger, final_output_path = create_logger(config.OUTPUT_PATH,
                                              args.cfg,
                                              config.DATASET[0].TRAIN_IMAGE_SET if isinstance(config.DATASET, list)
                                              else config.DATASET.TRAIN_IMAGE_SET,
                                              split='train')
    model_prefix = os.path.join(final_output_path, config.MODEL_PREFIX)
    if args.log_dir is None:
        args.log_dir = os.path.join(final_output_path, 'tensorboard_logs')

    pprint.pprint(args)
    logger.info('training args:{}\n'.format(args))
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # manually set random seed
    if config.RNG_SEED > -1:
        random.seed(config.RNG_SEED)
        np.random.seed(config.RNG_SEED)
        torch.random.manual_seed(config.RNG_SEED)
        torch.cuda.manual_seed_all(config.RNG_SEED)

    # cudnn
    torch.backends.cudnn.benchmark = False
    if args.cudnn_off:
        torch.backends.cudnn.enabled = False

    if args.dist:
        model = eval(config.MODULE)(config)
        local_rank = int(os.environ.get('LOCAL_RANK') or 0)
        config.GPUS = str(local_rank)
        torch.cuda.set_device(local_rank)
        master_address = os.environ['MASTER_ADDR']
        master_port = int(os.environ.get('MASTER_PORT') or 23456)
        world_size = int(os.environ.get('WORLD_SIZE') or 1)
        rank = int(os.environ.get('RANK') or 0)
        if args.slurm:
            distributed.init_process_group(backend='nccl')
        else:
            distributed.init_process_group(
                backend='nccl',
                init_method='tcp://{}:{}'.format(master_address, master_port),
                world_size=world_size,
                rank=rank,
                group_name='mtorch')
        print(f'native distributed, size: {world_size}, rank: {rank}, local rank: {local_rank}')
        torch.cuda.set_device(local_rank)
        config.GPUS = str(local_rank)
        model = model.cuda()
        if not config.TRAIN.FP16:
            model = DDP(model, device_ids=[local_rank], output_device=local_rank)

        if rank == 0:
            summary_parameters(model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model,
                               logger)
            shutil.copy(args.cfg, final_output_path)
            shutil.copy(inspect.getfile(eval(config.MODULE)), final_output_path)

        writer = None
        if args.log_dir is not None:
            tb_log_dir = os.path.join(args.log_dir, 'rank{}'.format(rank))
            if not os.path.exists(tb_log_dir):
                os.makedirs(tb_log_dir)
            writer = SummaryWriter(log_dir=tb_log_dir)

        if isinstance(config.DATASET, list):
            train_loaders_and_samplers = make_dataloaders(config,
                                                          mode='train',
                                                          distributed=True,
                                                          num_replicas=world_size,
                                                          rank=rank,
                                                          expose_sampler=True)

            train_loader = MultiTaskDataLoader([loader for loader, _ in train_loaders_and_samplers])
            train_sampler = train_loaders_and_samplers[0][1]
        else:
            train_loader, train_sampler = make_dataloader(config,
                                                          mode='train',
                                                          distributed=True,
                                                          num_replicas=world_size,
                                                          rank=rank,
                                                          expose_sampler=True)

        batch_size = world_size * (sum(config.TRAIN.BATCH_IMAGES)
                                   if isinstance(config.TRAIN.BATCH_IMAGES, list)
                                   else config.TRAIN.BATCH_IMAGES)
        if config.TRAIN.GRAD_ACCUMULATE_STEPS > 1:
            batch_size = batch_size * config.TRAIN.GRAD_ACCUMULATE_STEPS
        base_lr = config.TRAIN.LR * batch_size
        optimizer_grouped_parameters = [{'params': [p for n, p in model.named_parameters() if _k in n],
                                         'lr': base_lr * _lr_mult}
                                        for _k, _lr_mult in config.TRAIN.LR_MULT]
        optimizer_grouped_parameters.append({'params': [p for n, p in model.named_parameters()
                                                        if all([_k not in n for _k, _ in config.TRAIN.LR_MULT])]})
        if config.TRAIN.OPTIMIZER == 'SGD':
            optimizer = optim.SGD(optimizer_grouped_parameters,
                                  lr=config.TRAIN.LR * batch_size,
                                  momentum=config.TRAIN.MOMENTUM,
                                  weight_decay=config.TRAIN.WD)
        elif config.TRAIN.OPTIMIZER == 'Adam':
            optimizer = optim.Adam(optimizer_grouped_parameters,
                                   lr=config.TRAIN.LR * batch_size,
                                   weight_decay=config.TRAIN.WD)
        elif config.TRAIN.OPTIMIZER == 'AdamW':
            optimizer = AdamW(optimizer_grouped_parameters,
                              lr=config.TRAIN.LR * batch_size,
                              betas=(0.9, 0.999),
                              eps=1e-6,
                              weight_decay=config.TRAIN.WD,
                              correct_bias=True)
        else:
            raise ValueError('Unsupported optimizer {}!'.format(config.TRAIN.OPTIMIZER))
        total_gpus = world_size

    else:
        #os.environ['CUDA_VISIBLE_DEVICES'] = config.GPUS
        model = eval(config.MODULE)(config)
        summary_parameters(model, logger)
        shutil.copy(args.cfg, final_output_path)
        shutil.copy(inspect.getfile(eval(config.MODULE)), final_output_path)
        num_gpus = len(config.GPUS.split(','))
        assert num_gpus <= 1 or (not config.TRAIN.FP16), "fp16 is not supported with torch.nn.DataParallel. " \
                                                         "Please use apex.parallel.DistributedDataParallel instead."
        total_gpus = num_gpus
        rank = None
        writer = SummaryWriter(log_dir=args.log_dir) if args.log_dir is not None else None

        # model
        if num_gpus > 1:
            model = torch.nn.DataParallel(model, device_ids=[int(d) for d in config.GPUS.split(',')]).cuda()
        else:
            torch.cuda.set_device(int(config.GPUS))
            model.cuda()

        # loader
        if isinstance(config.DATASET, list):
            train_loaders = make_dataloaders(config, mode='train', distributed=False)
            train_loader = MultiTaskDataLoader(train_loaders)
        else:
            train_loader = make_dataloader(config, mode='train', distributed=False)
        train_sampler = None

        batch_size = num_gpus * (sum(config.TRAIN.BATCH_IMAGES) if isinstance(config.TRAIN.BATCH_IMAGES, list)
                                 else config.TRAIN.BATCH_IMAGES)

    # partial load pretrain state dict
    if config.NETWORK.PARTIAL_PRETRAIN != "":
        pretrain_state_dict = torch.load(config.NETWORK.PARTIAL_PRETRAIN, map_location=lambda storage, loc: storage)['state_dict']
        prefix_change = [prefix_change.split('->') for prefix_change in config.NETWORK.PARTIAL_PRETRAIN_PREFIX_CHANGES]
        if len(prefix_change) > 0:
            pretrain_state_dict_parsed = {}
            for k, v in pretrain_state_dict.items():
                no_match = True
                for pretrain_prefix, new_prefix in prefix_change:
                    if k.startswith(pretrain_prefix):
                        k = new_prefix + k[len(pretrain_prefix):]
                        pretrain_state_dict_parsed[k] = v
                        no_match = False
                        break
                if no_match:
                    pretrain_state_dict_parsed[k] = v
            pretrain_state_dict = pretrain_state_dict_parsed
        smart_partial_load_model_state_dict(model, pretrain_state_dict)


    # batch end callbacks
    batch_size = len(config.GPUS.split(',')) * (sum(config.TRAIN.BATCH_IMAGES)
                                                if isinstance(config.TRAIN.BATCH_IMAGES, list)
                                                else config.TRAIN.BATCH_IMAGES)
    batch_end_callbacks = [Speedometer(batch_size, config.LOG_FREQUENT,
                                       batches_per_epoch=len(train_loader),
                                       epochs=1)]

    # broadcast parameter from rank 0 before training start
    if args.dist:
        for v in model.state_dict().values():
            distributed.broadcast(v, src=0)

    # set net to eval mode (gradients are still computed below for the Fisher estimate)
    model.eval()


    # init end time
    end_time = time.time()

    # Parameter to pass to batch_end_callback
    BatchEndParam = namedtuple('BatchEndParams',
                               ['epoch',
                                'nbatch',
                                'rank',
                                'add_step',
                                'data_in_time',
                                'data_transfer_time',
                                'forward_time',
                                'backward_time',
                                'optimizer_time',
                                'metric_time',
                                'eval_metric',
                                'locals'])

    def _multiple_callbacks(callbacks, *args, **kwargs):
        """Sends args and kwargs to any configured callbacks.
        This handles the cases where the 'callbacks' variable
        is ``None``, a single function, or a list.
        """
        if isinstance(callbacks, list):
            for cb in callbacks:
                cb(*args, **kwargs)
            return
        if callbacks:
            callbacks(*args, **kwargs)

    # initialize Fisher
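    # Diagonal empirical Fisher for EWC: after one full pass over the training data, fisher[n]
    # holds the average squared gradient of each parameter, and the current parameters are
    # saved as "params_pretrain", presumably the anchor for the EWC penalty in later fine-tuning.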
    fisher = {}
    for n, p in model.named_parameters():
        fisher[n] = p.new_zeros(p.size())
        p.requires_grad = True
        p.retain_grad()

    # training
    for nbatch, batch in enumerate(train_loader):
        model.zero_grad()
        global_steps = len(train_loader) + nbatch
        os.environ['global_steps'] = str(global_steps)

        # record time
        data_in_time = time.time() - end_time

        # transfer data to GPU
        data_transfer_time = time.time()
        batch = to_cuda(batch)
        data_transfer_time = time.time() - data_transfer_time

        # forward
        forward_time = time.time()
        outputs, loss = model(*batch)
        loss = loss.mean()
        forward_time = time.time() - forward_time

        # backward
        backward_time = time.time()
        loss.backward()

        backward_time = time.time() - backward_time

        for n, p in model.named_parameters():
            assert p.grad is not None, 'gradient is None for parameter {}'.format(n)
            fisher[n] += p.grad**2 / len(train_loader)
        batch_end_params = BatchEndParam(epoch=0, nbatch=nbatch, add_step=True, rank=rank,
                                         data_in_time=data_in_time, data_transfer_time=data_transfer_time,
                                         forward_time=forward_time, backward_time=backward_time,
                                         optimizer_time=0., metric_time=0.,
                                         eval_metric=None, locals=locals())
        _multiple_callbacks(batch_end_callbacks, batch_end_params)
    with open(os.path.join(config.EWC_STATS_PATH, "fisher"), "wb") as fisher_file:
        pickle.dump(fisher, fisher_file)
    torch.save(model.state_dict(), os.path.join(config.EWC_STATS_PATH, "params_pretrain"))