Exemplo n.º 1
0
 def _mining_epoch(self, mining_epoch_size, mining_dataset_size):
   """Take exam, collect and update positive dataset and negative dataset.

   Runs up to ``mining_epoch_size`` evaluation steps, routing each result
   into the positive or the negative dataset (each capped at
   ``mining_dataset_size``), and stops early once both datasets are full.

   Args:
     mining_epoch_size: maximum number of mining steps to run.
     mining_dataset_size: capacity of each of the two mined datasets.

   Returns:
     The GroupMeters accumulated over the mining steps.
   """
   pos_data = RandomlyIterDataset()
   neg_data = RandomlyIterDataset()
   # Mining runs with the model in evaluation mode; no parameter updates.
   self.model.eval()
   meters = GroupMeters()
   with tqdm_pbar(total=mining_epoch_size) as pbar:
     for i in range(mining_epoch_size):
       message, result = self._get_result(i, meters, mode='mining')
       positive, number, backup = self._extract_info(result)
       # Route the sample to the matching dataset; drop it silently once
       # that dataset has reached its capacity.
       dataset = pos_data if positive else neg_data
       if dataset.size < mining_dataset_size:
         dataset.append((number, backup))
       pbar.set_description(message)
       pbar.update()
       # When both positive and negative dataset are full, break.
       if pos_data.size >= mining_dataset_size and \
               neg_data.size >= mining_dataset_size:
         break
   logger.info(meters.format_simple('> Mining: ', compressed=False))
   # Carry over (part of) the previous negative dataset before replacing it.
   self._inherit_neg_data(neg_data, self.neg_data, meters, mining_dataset_size)
   self.pos_data = pos_data
   self.neg_data = neg_data
   self._dump_meters(meters, 'mining')
   return meters
Exemplo n.º 2
0
def main():
    """Evaluate per-object predictions against ground-truth scenes.

    Loads the scenes and predictions from the paths configured in ``args``,
    runs ``test`` on every object of every scene, and prints the aggregated
    meter values.
    """
    scenes = io.load_json(args.scene_json)['scenes']
    preds = io.load(args.preds_json)
    # Predictions may be stored as a dict keyed by id; normalize to a list
    # aligned with `scenes` by index.
    if isinstance(preds, dict):
        preds = list(preds.values())

    meter = GroupMeters()

    for i, pred in tqdm_gofor(preds, mininterval=0.5):
        scene = scenes[i]
        for j in range(len(scene['objects'])):
            test(j, scene['objects'], pred, meter)

    print(meter.format_simple('Results:', compressed=False))
Exemplo n.º 3
0
def main():
    """Build the model and dataset, then run a single validation epoch.

    Uses the global ``args``/``configs`` for all paths and hyper-parameters.
    Returns the GroupMeters collected during validation.
    """
    # Prepare the dump/checkpoint/meta directories (created if missing).
    args.dump_dir = ensure_path(
        osp.join('dumps', args.dataset_name, args.desc_name, args.expr))
    args.ckpt_dir = ensure_path(osp.join(args.dump_dir, 'checkpoints'))
    args.meta_dir = ensure_path(osp.join(args.dump_dir, 'meta'))
    args.vis_dir = osp.join(args.dump_dir, 'vis', args.run_name)

    initialize_dataset(args.dataset)
    build_dataset = get_dataset_builder(args.dataset)

    dataset = build_dataset(args, configs, args.data_image_root,
                            args.data_scenes_json, args.data_questions_json)

    # data_split <= 1 is interpreted as a fraction of the dataset; larger
    # values as an absolute number of examples.
    dataset_split = int(len(dataset) *
                        args.data_split) if args.data_split <= 1 else int(
                            args.data_split)
    train_dataset, validation_dataset = dataset.split_trainval(dataset_split)

    logger.critical('Building the model.')
    model = desc.make_model(args, train_dataset.unwrapped.vocab)

    if args.use_gpu:
        model.cuda()
        # Use the customized data parallel if applicable.
        if args.gpu_parallel:
            from jactorch.parallel import JacDataParallel
            # from jactorch.parallel import UserScatteredJacDataParallel as JacDataParallel
            model = JacDataParallel(model, device_ids=args.gpus).cuda()
        # Disable the cudnn benchmark.
        cudnn.benchmark = False

    if args.load:
        from jactorch.io import load_weights
        if load_weights(model, args.load):
            logger.critical(
                'Loaded weights from pretrained model: "{}".'.format(
                    args.load))

    from jacinle.utils.meter import GroupMeters
    meters = GroupMeters()

    # Optional interactive debugging session before validation starts.
    if args.embed:
        from IPython import embed
        embed()

    logger.critical('Building the data loader.')
    validation_dataloader = validation_dataset.make_dataloader(
        args.batch_size,
        shuffle=True,
        drop_last=False,
        nr_workers=args.data_workers)

    model.eval()
    # NOTE(review): `validate_epoch` is called with the raw model here (other
    # snippets pass a trainer) — confirm the expected signature.
    validate_epoch(0, model, validation_dataloader, meters)
    # Only report meters with non-zero averages.
    logger.critical(
        meters.format_simple('Validation',
                             {k: v
                              for k, v in meters.avg.items() if v != 0},
                             compressed=False))
    return meters
Exemplo n.º 4
0
  def _test_epoch(self, epoch_size):
    """Run one evaluation epoch of `epoch_size` cases and return the meters."""
    eval_meters = GroupMeters()
    self._prepare_dataset(epoch_size, mode='test')

    def run_single(index):
      # Evaluate one test case; the returned message feeds the progress bar.
      message, _ = self._get_result(index, eval_meters, mode='test')
      return message

    tqdm_for(epoch_size, run_single)
    logger.info(eval_meters.format_simple('> Evaluation: ', compressed=False))
    self._dump_meters(eval_meters, 'test')
    return eval_meters
Exemplo n.º 5
0
    def train_epoch(self, data_loader, meters=None):
        """Train the model for a single pass over `data_loader`.

        A fresh GroupMeters is created when none is supplied. Per-batch
        data-loading and optimization times are recorded under
        'time/data' and 'time/step'.
        """
        meters = GroupMeters() if meters is None else meters

        self._model.train()
        tic = time.time()
        for feed_dict in data_loader:
            load_elapsed = time.time() - tic
            tic = time.time()
            self.train_step(feed_dict, meters=meters)
            step_elapsed = time.time() - tic
            tic = time.time()
            meters.update({'time/data': load_elapsed, 'time/step': step_elapsed})
        return meters
Exemplo n.º 6
0
    def validate(self, data_loader, metric, meters=None):
        """Evaluate the model over `data_loader` and return averaged meters.

        Times per-batch data loading ('time/data') and evaluation
        ('time/step'); returns `meters.avg` rather than the meters object.
        """
        meters = GroupMeters() if meters is None else meters

        self._model.eval()
        tic = time.time()
        for feed_dict in data_loader:
            load_elapsed = time.time() - tic
            tic = time.time()
            self.validate_step(feed_dict, metric, meters=meters)
            step_elapsed = time.time() - tic
            tic = time.time()
            meters.update({'time/data': load_elapsed, 'time/step': step_elapsed})

        return meters.avg
def main_train(validation_dataset):
    """Build the model, optionally load weights, and run attribute validation.

    Args:
        validation_dataset: dataset exposing ``make_dataloader``.

    Returns:
        The GroupMeters collected by ``validate_attribute``.
    """
    logger.critical('Building the model.')
    model = desc.make_model(args)

    if args.use_gpu:
        model.cuda()
        # Use the customized data parallel if applicable.
        if args.gpu_parallel:
            from jactorch.parallel import JacDataParallel
            # from jactorch.parallel import UserScatteredJacDataParallel as JacDataParallel
            model = JacDataParallel(model, device_ids=args.gpus).cuda()
        # Disable the cudnn benchmark.
        cudnn.benchmark = False

    trainer = TrainerEnv(model, None)

    if args.load:
        if trainer.load_weights(args.load):
            logger.critical(
                'Loaded weights from pretrained model: "{}".'.format(
                    args.load))

    # BUG FIX: `meters` was previously created inside the `if args.load:`
    # branch, so running without --load raised NameError at meters.reset().
    # Create it unconditionally.
    from jacinle.utils.meter import GroupMeters
    meters = GroupMeters()

    logger.critical('Building the data loader.')
    validation_dataloader = validation_dataset.make_dataloader(
        args.batch_size,
        shuffle=False,
        drop_last=False,
        nr_workers=args.data_workers)

    meters.reset()
    model.eval()

    # Make sure the output directory exists before validation writes into it.
    if not os.path.isdir(args.output_attr_path):
        os.makedirs(args.output_attr_path)
    validate_attribute(model, validation_dataloader, meters, args.setname,
                       logger, args.output_attr_path)
    # Only report meters with non-zero averages.
    logger.critical(
        meters.format_simple(args.setname,
                             {k: v
                              for k, v in meters.avg.items() if v != 0},
                             compressed=False))
    return meters
Exemplo n.º 8
0
def main():
    """Score flattened per-concept predictions against scene annotations."""
    scenes = io.load_json(args.scene_json)['scenes']
    preds = io.load(args.preds_json)
    if isinstance(preds, dict):
        preds = list(preds.values())
    if False:
        preds = [transpose_scene(s) for s in preds]
    meter = GroupMeters()

    # Flatten per-scene objects and predictions into parallel flat structures.
    flattened_objs = [obj for scene in scenes for obj in scene['objects']]
    flattened_preds = {
        key: np.concatenate([np.array(p[key]) for p in preds], axis=0)
        for key in preds[0]
    }

    for concept, scores in flattened_preds.items():
        lemma = def_.word2lemma.get(concept, concept)
        for i, obj in tqdm_gofor(flattened_objs,
                                 desc='{}(lemma: {})'.format(concept, lemma),
                                 leave=False):
            # Positive score must agree with the object carrying this lemma.
            correct = (scores[i] > 0) == (lemma == obj[def_.concept2attribute[lemma]])
            meter.update('acc', correct)
            meter.update(f'acc/{concept}', correct)
    print(meter.format_simple('Results:', compressed=False))
Exemplo n.º 9
0
  def _train_epoch(self, epoch_size):
    """Run one training epoch of `epoch_size` steps and return the meters."""
    net = self.model
    epoch_meters = GroupMeters()
    self._prepare_dataset(epoch_size, mode='train')

    def step(index):
      # Training data is generated with the model in eval mode; the actual
      # parameter update happens in train mode.
      net.eval()
      feed_dict = self._get_train_data(index, epoch_meters)
      net.train()
      message, _ = self._train_step(feed_dict, epoch_meters)
      return message

    # Execute `step` epoch_size times behind a tqdm progress bar.
    tqdm_for(epoch_size, step)
    logger.info(
        epoch_meters.format_simple(
            '> Train Epoch {:5d}: '.format(self.current_epoch),
            compressed=False))
    self._dump_meters(epoch_meters, 'train')
    return epoch_meters
Exemplo n.º 10
0
    def train(self,
              data_loader,
              nr_epochs,
              verbose=True,
              meters=None,
              early_stop=None,
              print_interval=1):
        """Train for `nr_epochs` epochs with per-epoch meter resets.

        Logs the meters every `print_interval` epochs when `verbose`, and
        stops early as soon as `early_stop(model)` returns a truthy value.
        """
        meters = GroupMeters() if meters is None else meters

        for epoch_index in range(1, nr_epochs + 1):
            meters.reset()
            self.train_epoch(data_loader, meters=meters)
            if verbose and epoch_index % print_interval == 0:
                logger.info(
                    meters.format_simple('Epoch: {}:'.format(epoch_index)))
            if early_stop is not None and early_stop(self._model):
                break
Exemplo n.º 11
0
def main():
    """Evaluate predictions on a 1000-scene prefix, object by object."""
    scenes = io.load_json(args.scene_json)['scenes']
    preds = io.load(args.preds_json)
    if isinstance(preds, dict):
        preds = list(preds.values())
    if False:
        preds = [transpose_scene(s) for s in preds]
    # Restrict evaluation to a fixed-size prefix.
    scenes = scenes[:1000]
    preds = preds[:1000]

    # Flatten per-scene objects and predictions into parallel flat structures.
    flattened_objs = [obj for scene in scenes for obj in scene['objects']]
    flattened_preds = {
        key: np.concatenate([np.array(p[key]) for p in preds], axis=0)
        for key in preds[0]
    }
    meter = GroupMeters()

    for index, _obj in tqdm_gofor(flattened_objs, mininterval=0.5):
        test(index, flattened_objs, flattened_preds, meter)

    print(meter.format_simple('Results:', compressed=False))
Exemplo n.º 12
0
def main():
    """Sanity-check the symbolic executor against ground-truth answers.

    Runs `execute_program` on every (program, scene) pair of the symbolic
    dataset and tracks accuracy; drops into an IPython shell on the first
    mismatch for interactive debugging.
    """
    initialize_dataset(args.dataset)
    build_symbolic_dataset = get_symbolic_dataset_builder(args.dataset)
    dataset = build_symbolic_dataset(args)
    dataloader = dataset.make_dataloader(32, False, False, nr_workers=4)
    meters = GroupMeters()

    for idx, feed_dict in tqdm_gofor(dataloader):
        feed_dict = GView(feed_dict)

        for i, (p, s, gt) in enumerate(
                zip(feed_dict.program_seq, feed_dict.scene, feed_dict.answer)):
            _, pred = execute_program(p, s)

            # An 'error' tag means the executor raised; re-raise it here.
            if pred[0] == 'error':
                raise pred[1]

            # On the first wrong answer, dump the case and open a debug shell.
            if pred[1] != gt:
                print(p)
                print(s)

                from IPython import embed
                embed()
                from sys import exit
                exit()

            meters.update('accuracy', pred[1] == gt)
        get_current_tqdm().set_description(
            meters.format_simple('Exec:', 'val', compressed=True))

    logger.critical(
        meters.format_simple('Symbolic execution test:',
                             'avg',
                             compressed=False))
Exemplo n.º 13
0
def ref_epoch(coach, prepare_fn, recording, ref_dataset):
    """Run one referring-expression eval epoch and record averaged meters."""
    prepare_fn()
    recording.reset()

    predictions = get_preds(coach.model, ref_dataset, coach.logger)
    meter = GroupMeters()

    for image_id, pred in coach.logger.tqdm(predictions.items(),
                                            mininterval=0.5):
        scene = ref_dataset.sceneGraphs[image_id]
        for obj_index in range(len(scene['objects'])):
            test(obj_index, list(scene['objects'].values()), pred, meter)

    recording.record(meter.avg)
Exemplo n.º 14
0
def main():
    """Evaluate completion checkpoints 1..10 on the dev and test splits.

    For each saved epoch, runs the model over both data loaders, collecting
    the 'top*' metrics emitted by the model plus timing meters, and prints
    the per-epoch averages.
    """
    logger.critical('Loading the word embedding.')
    vocab, word_embeddings = load_word_embedding(args.vse)

    logger.critical('Building up the model.')
    model = CompletionModel(word_embeddings)
    if args.use_gpu:
        model.cuda()
    # Disable the cudnn benchmark.
    model.eval()
    cudnn.benchmark = False

    logger.critical('Loading the dataset.')

    dev_dataset = CompletionDataset(vocab, pjoin(args.data_dir, args.dev_img), pjoin(args.data_dir, args.dev_cap), mode=args.mode)
    test_dataset = CompletionDataset(vocab, pjoin(args.data_dir, args.test_img), pjoin(args.data_dir, args.test_cap), mode=args.mode)

    logger.critical('Building up the data loader.')
    dev_dataloader = make_dataloader(dev_dataset, num_workers=args.data_workers, batch_size=64, shuffle=False, drop_last=False, pin_memory=True)
    test_dataloader = make_dataloader(test_dataset, num_workers=args.data_workers, batch_size=64, shuffle=False, drop_last=False, pin_memory=True)

    # Checkpoints are assumed to be named epoch_1.pth .. epoch_10.pth.
    for epoch_id in range(1, 11):
        load_weights(model, pjoin(args.load, 'epoch_{}.pth'.format(epoch_id)))

        for loader in [dev_dataloader, test_dataloader]:
            meters = GroupMeters()

            end = time.time()
            with tqdm_pbar(total=len(loader), leave=False) as pbar:
                for i, data in enumerate(loader):
                    feed_dict = data
                    feed_dict = mark_volatile(feed_dict)

                    if args.use_gpu:
                        feed_dict = async_copy_to(feed_dict, 0)

                    data_time = time.time() - end; end = time.time()

                    output_dict = model(feed_dict)
                    output_dict = as_numpy(output_dict)

                    gpu_time = time.time() - end;  end = time.time()

                    # Keep only the accuracy-style 'top*' outputs, weighted
                    # by the batch size.
                    meters.update({k: float(v) for k, v in output_dict.items() if k.startswith('top')}, n=len(feed_dict['image']))
                    meters.update({'time/data': data_time, 'time/gpu': gpu_time})

                    pbar.set_description(format_meters('sentid={}'.format(i), meters.val, '{}={:.4f}', ', '))
                    pbar.update()

                    end = time.time()

            print(epoch_id, sorted(meters.avg.items()))
Exemplo n.º 15
0
 def _mining_epoch(self, mining_epoch_size, mining_dataset_size):
     """Take exam, collect and update positive dataset and negative dataset.

     Alternates between a deterministic and a stochastic evaluation mode
     (even steps -> 'mining-deter', odd steps -> 'mining-stoch'), routing
     each result into the positive or negative dataset, and stops early
     once both datasets reach ``mining_dataset_size``.

     Returns:
       A tuple ``(meters, meters_deter, meters_stoch)`` where ``meters``
       is the better of the two per-mode meter groups.
     """
     pos_data = RandomlyIterDataset()
     neg_data = RandomlyIterDataset()
     self.model.eval()
     meters_deter = GroupMeters()
     meters_stoch = GroupMeters()
     # Disable the progress bar on cluster jobs to keep logs clean.
     disable_pbar = False
     if os.getenv("ONCLUSTER") is not None:
         disable_pbar = True
     with tqdm_pbar(total=mining_epoch_size, disable=disable_pbar) as pbar:
         for i in range(mining_epoch_size):
             # Alternate deterministic / stochastic evaluation per step.
             if i % 2 == 0:
                 message, result = self._get_result(i,
                                                    meters_deter,
                                                    mode='mining-deter')
             else:
                 message, result = self._get_result(i,
                                                    meters_stoch,
                                                    mode='mining-stoch')
             positive, number, backup = self._extract_info(result)
             # Route the sample; drop it once its dataset is at capacity.
             dataset = pos_data if positive else neg_data
             if dataset.size < mining_dataset_size:
                 dataset.append((number, backup))
             pbar.set_description(message)
             pbar.update()
             # When both positive and negative dataset are full, break.
             if pos_data.size >= mining_dataset_size and \
                     neg_data.size >= mining_dataset_size:
                 break
     logger.info(
         meters_deter.format_simple('> Mining (deter): ', compressed=False))
     logger.info(
         meters_stoch.format_simple('> Mining (stoch): ', compressed=False))
     meters = self.best_meters(meters_deter, meters_stoch)
     # Carry over (part of) the previous negative dataset before replacing it.
     self._inherit_neg_data(neg_data, self.neg_data, meters,
                            mining_dataset_size)
     self.pos_data = pos_data
     self.neg_data = neg_data
     self._dump_meters(meters_deter, 'mining-deter')
     self._dump_meters(meters_stoch, 'mining-stoch')
     return meters, meters_deter, meters_stoch
Exemplo n.º 16
0
def main_train(
    train_dataset,
    validation_dataset,
    test_dataset=None,
    prototype_dataset=None,
    one_shot_dataset=None,
):
    """Build the model/optimizer and run curriculum training or one-shot eval.

    When ``test_dataset`` is given, skips training entirely and runs
    ``main_one_shot`` instead, returning the meters. Otherwise trains for
    ``args.epochs`` epochs with the hard-coded curriculum schedule below.
    """
    logger.critical("Building the model.")
    model = desc.make_model(args, train_dataset.unwrapped.vocab)

    if args.use_gpu:
        model.cuda()
        # Use the customized data parallel if applicable.
        if args.gpu_parallel:
            from jactorch.parallel import JacDataParallel

            # from jactorch.parallel import UserScatteredJacDataParallel as JacDataParallel
            model = JacDataParallel(model, device_ids=args.gpus).cuda()
        # Disable the cudnn benchmark.
        cudnn.benchmark = False

    # Prefer a description-provided optimizer; otherwise AdamW over all
    # trainable parameters.
    if hasattr(desc, "make_optimizer"):
        logger.critical("Building customized optimizer.")
        optimizer = desc.make_optimizer(model, args.lr)
    else:
        from jactorch.optim import AdamW

        trainable_parameters = filter(lambda x: x.requires_grad,
                                      model.parameters())
        optimizer = AdamW(trainable_parameters,
                          args.lr,
                          weight_decay=configs.train.weight_decay)

    if args.acc_grad > 1:
        from jactorch.optim import AccumGrad

        optimizer = AccumGrad(optimizer, args.acc_grad)
        logger.warning(
            "Use accumulated grad={:d}, effective iterations per epoch={:d}.".
            format(args.acc_grad, int(args.iters_per_epoch / args.acc_grad)))

    trainer = TrainerEnv(model, optimizer)

    # Resume a full checkpoint (restores epoch counter) or just load weights.
    if args.resume:
        extra = trainer.load_checkpoint(args.resume)
        if extra:
            args.start_epoch = extra["epoch"]
            logger.critical("Resume from epoch {}.".format(args.start_epoch))
    elif args.load:
        if trainer.load_weights(args.load):
            logger.critical(
                'Loaded weights from pretrained model: "{}".'.format(
                    args.load))

    if args.use_tb and not args.debug:
        from jactorch.train.tb import TBLogger, TBGroupMeters

        tb_logger = TBLogger(args.tb_dir)
        meters = TBGroupMeters(tb_logger)
        logger.critical('Writing tensorboard logs to: "{}".'.format(
            args.tb_dir))
    else:
        from jacinle.utils.meter import GroupMeters

        meters = GroupMeters()

    if not args.debug:
        logger.critical('Writing meter logs to file: "{}".'.format(
            args.meter_file))

    if args.clip_grad:
        logger.info("Registering the clip_grad hook: {}.".format(
            args.clip_grad))

        def clip_grad(self, loss):
            # Clip gradient norms after every backward pass.
            from torch.nn.utils import clip_grad_norm_

            clip_grad_norm_(self.model.parameters(), max_norm=args.clip_grad)

        trainer.register_event("backward:after", clip_grad)

    if hasattr(desc, "customize_trainer"):
        desc.customize_trainer(trainer)

    # Optional interactive debugging session before the loops start.
    if args.embed:
        from IPython import embed

        embed()

    logger.critical("Building the data loader.")
    validation_dataloader = validation_dataset.make_dataloader(
        args.batch_size,
        shuffle=False,
        drop_last=False,
        nr_workers=args.data_workers)
    # NOTE(review): when test_dataset is given, this function runs only the
    # one-shot evaluation and returns; the test_dataloader construction is
    # commented out below.
    if test_dataset is not None:
        #     test_dataloader = {
        #         dataset: test_dataset[dataset].make_dataloader(
        #             args.batch_size,
        #             shuffle=False,
        #             drop_last=False,
        #             nr_workers=args.data_workers,
        #         )
        #         for dataset in test_dataset
        #     }

        # if args.evaluate:
        #     meters.reset()
        #     model.eval()
        #     validate_epoch(args.start_epoch, trainer, validation_dataloader, meters)
        #     if test_dataset is not None:
        #         for dataloader in test_dataloader:
        #             validate_epoch(
        #                 args.start_epoch,
        #                 trainer,
        #                 test_dataloader[dataloader],
        #                 meters,
        #                 meter_prefix=dataloader,
        #             )
        # logger.critical(
        #     meters.format_simple(
        #         "Validation",
        #         {k: v for k, v in meters.avg.items() if v != 0},
        #         compressed=False,
        #     )
        # )
        main_one_shot(
            prototype_dataset,
            one_shot_dataset,
            model,
            args.start_epoch,
            trainer,
            meters,
            args.batch_size,
        )
        if not args.debug:
            meters.dump(args.meter_file)

        return meters

    # assert args.curriculum == 'off', 'Unimplemented feature: curriculum mode {}.'.format(args.curriculum)
    # Curriculum entries: (start_epoch, max_scene_size, max_program_size);
    # the final sentinel entry only serves as an upper bound for the lookup.
    curriculum_strategy = [
        (0, 3, 4),
        (5, 3, 6),
        (10, 3, 8),
        (15, 4, 8),
        (25, 4, 12),
        (35, 5, 12),
        (45, 6, 12),
        (55, 7, 16),
        (65, 8, 20),
        (75, 9, 22),
        (90, 10, 25),
        (1e9, None, None),
    ]

    # trainer.register_event('backward:after', backward_check_nan)
    # args.curriculum = "off"

    for epoch in range(args.start_epoch + 1, args.epochs + 1):
        meters.reset()

        model.train()

        # Filter the training data according to the current curriculum stage.
        this_train_dataset = train_dataset
        if args.curriculum != "off":
            for si, s in enumerate(curriculum_strategy):
                if curriculum_strategy[si][0] < epoch <= curriculum_strategy[
                        si + 1][0]:
                    max_scene_size, max_program_size = s[1:]
                    if args.curriculum in ("scene", "all"):
                        this_train_dataset = this_train_dataset.filter_scene_size(
                            max_scene_size)
                    if args.curriculum in ("program", "all"):
                        this_train_dataset = this_train_dataset.filter_program_size_raw(
                            max_program_size)
                    logger.critical(
                        "Building the data loader. Curriculum = {}/{}, length = {}."
                        .format(*s[1:], len(this_train_dataset)))
                    break

        train_dataloader = this_train_dataset.make_dataloader(
            args.batch_size,
            shuffle=True,
            drop_last=True,
            nr_workers=args.data_workers)

        for enum_id in range(args.enums_per_epoch):
            train_epoch(epoch, trainer, train_dataloader, meters)

        if epoch % args.validation_interval == 0:
            model.eval()
            validate_epoch(epoch, trainer, validation_dataloader, meters)

        if not args.debug:
            meters.dump(args.meter_file)

        # Hide stale validation/test meters on non-validation epochs.
        logger.critical(
            meters.format_simple(
                "Epoch = {}".format(epoch),
                {
                    k: v
                    for k, v in meters.avg.items()
                    if epoch % args.validation_interval == 0
                    or not (k.startswith("validation") or k.startswith("test"))
                },
                compressed=False,
            ))

        if epoch % args.save_interval == 0 and not args.debug:
            fname = osp.join(args.ckpt_dir, "epoch_{}.pth".format(epoch))
            trainer.save_checkpoint(
                fname, dict(epoch=epoch, meta_file=args.meta_file))

        # Decay the learning rate after 60% of the epochs.
        if epoch > int(args.epochs * 0.6):
            trainer.set_learning_rate(args.lr * 0.1)

    # NOTE(review): this branch is dead — when test_dataset is not None the
    # function already returned above, and `test_dataloader` is never defined
    # here (its construction is commented out), so reaching it would raise
    # NameError.
    if test_dataset is not None:
        model.eval()
        for dataloader in test_dataloader:
            validate_epoch(
                epoch,
                trainer,
                test_dataloader[dataloader],
                meters,
                meter_prefix=dataloader,
            )
        if not args.debug:
            meters.dump(args.meter_file)
    # NOTE(review): `epoch` is unbound here if the training range is empty
    # (args.start_epoch >= args.epochs) — confirm callers guarantee at least
    # one epoch.
    main_one_shot(
        prototype_dataset,
        one_shot_dataset,
        model,
        epoch,
        trainer,
        meters,
        args.batch_size,
    )
Exemplo n.º 17
0
def main_train(train_dataset, validation_dataset, extra_dataset=None):
    """Train (or evaluate) the model with separate MONet/NSCL optimizers.

    Builds the model, partitions its parameters into MONet and non-MONet
    groups, optionally resumes/loads weights, then either runs a single
    validation pass (``args.evaluate``) or the curriculum training loop.

    Args:
        train_dataset: training dataset (provides ``unwrapped.vocab``).
        validation_dataset: dataset for periodic validation.
        extra_dataset: optional extra validation dataset.

    Returns:
        The GroupMeters instance when ``args.evaluate`` is set; otherwise
        None after training completes.
    """
    logger.critical('Building the model.')
    model = desc.make_model(args, train_dataset.unwrapped.vocab)

    if args.use_gpu:
        model.cuda()
        # Use the customized data parallel if applicable.
        if args.gpu_parallel:
            from jactorch.parallel import JacDataParallel
            # from jactorch.parallel import UserScatteredJacDataParallel as JacDataParallel
            model = JacDataParallel(model, device_ids=args.gpus).cuda()
        # Disable the cudnn benchmark.
        cudnn.benchmark = False

    if hasattr(desc, 'make_optimizer'):
        logger.critical('Building customized optimizer.')
        optimizer = desc.make_optimizer(model, args.lr)
    else:
        from jactorch.optim import AdamW
        # BUG FIX: the torch class is spelled RMSprop, not RMSProp.
        from torch.optim import RMSprop

        # BUG FIX: the original comprehension contained a duplicated,
        # dangling condition line (a SyntaxError), called `model.parameters`
        # without parentheses, and tested tensor membership with `in` (which
        # raises for torch tensors). Partition parameters by name instead.
        trainable_parameters_monet = [
            x for name, x in model.named_parameters()
            if x.requires_grad and 'monet' in name
        ]
        trainable_parameters_nscl = [
            x for name, x in model.named_parameters()
            if x.requires_grad and 'monet' not in name
        ]

        optimizer_monet = RMSprop(trainable_parameters_monet)
        optimizer_nscl = AdamW(trainable_parameters_nscl, args.lr,
                               weight_decay=configs.train.weight_decay)
        # TODO(review): TrainerEnv takes a single optimizer, so the MONet
        # parameters are currently not stepped by the trainer. Combine the
        # two optimizers (or use parameter groups) once supported.
        optimizer = optimizer_nscl

    if args.acc_grad > 1:
        from jactorch.optim import AccumGrad
        optimizer = AccumGrad(optimizer, args.acc_grad)
        logger.warning('Use accumulated grad={:d}, effective iterations per epoch={:d}.'.format(args.acc_grad, int(args.iters_per_epoch / args.acc_grad)))

    trainer = TrainerEnv(model, optimizer)

    # Resume a full checkpoint (restores epoch counter) or just load weights.
    if args.resume:
        extra = trainer.load_checkpoint(args.resume)
        if extra:
            args.start_epoch = extra['epoch']
            logger.critical('Resume from epoch {}.'.format(args.start_epoch))
    elif args.load:
        if trainer.load_weights(args.load):
            logger.critical('Loaded weights from pretrained model: "{}".'.format(args.load))

    if args.use_tb and not args.debug:
        from jactorch.train.tb import TBLogger, TBGroupMeters
        tb_logger = TBLogger(args.tb_dir)
        meters = TBGroupMeters(tb_logger)
        logger.critical('Writing tensorboard logs to: "{}".'.format(args.tb_dir))
    else:
        from jacinle.utils.meter import GroupMeters
        meters = GroupMeters()

    if not args.debug:
        logger.critical('Writing meter logs to file: "{}".'.format(args.meter_file))

    if args.clip_grad:
        logger.info('Registering the clip_grad hook: {}.'.format(args.clip_grad))

        def clip_grad(self, loss):
            # Clip gradient norms after every backward pass.
            from torch.nn.utils import clip_grad_norm_
            clip_grad_norm_(self.model.parameters(), max_norm=args.clip_grad)
        trainer.register_event('backward:after', clip_grad)

    if hasattr(desc, 'customize_trainer'):
        desc.customize_trainer(trainer)

    if args.embed:
        from IPython import embed; embed()

    logger.critical('Building the data loader.')
    validation_dataloader = validation_dataset.make_dataloader(args.batch_size, shuffle=False, drop_last=False, nr_workers=args.data_workers)
    if extra_dataset is not None:
        extra_dataloader = extra_dataset.make_dataloader(args.batch_size, shuffle=False, drop_last=False, nr_workers=args.data_workers)

    if args.evaluate:
        meters.reset()
        model.eval()
        validate_epoch(0, trainer, validation_dataloader, meters)
        if extra_dataset is not None:
            validate_epoch(0, trainer, extra_dataloader, meters, meter_prefix='validation_extra')
        logger.critical(meters.format_simple('Validation', {k: v for k, v in meters.avg.items() if v != 0}, compressed=False))
        return meters

    # assert args.curriculum == 'off', 'Unimplemented feature: curriculum mode {}.'.format(args.curriculum)
    # Curriculum entries: (start_epoch, max_scene_size, max_program_size);
    # the final sentinel entry only bounds the lookup below.
    curriculum_strategy = [
        (0, 3, 4),
        (5, 3, 6),
        (10, 3, 8),
        (15, 4, 8),
        (25, 4, 12),
        (35, 5, 12),
        (45, 6, 12),
        (55, 7, 16),
        (65, 8, 20),
        (75, 9, 22),
        (90, 10, 25),
        (1e9, None, None)
    ]

    # trainer.register_event('backward:after', backward_check_nan)

    for epoch in range(args.start_epoch + 1, args.epochs + 1):
        meters.reset()

        model.train()

        # Filter the training data according to the current curriculum stage.
        this_train_dataset = train_dataset
        if args.curriculum != 'off':
            for si, s in enumerate(curriculum_strategy):
                if curriculum_strategy[si][0] < epoch <= curriculum_strategy[si + 1][0]:
                    max_scene_size, max_program_size = s[1:]
                    if args.curriculum in ('scene', 'all'):
                        this_train_dataset = this_train_dataset.filter_scene_size(max_scene_size)
                    if args.curriculum in ('program', 'all'):
                        this_train_dataset = this_train_dataset.filter_program_size_raw(max_program_size)
                    logger.critical('Building the data loader. Curriculum = {}/{}, length = {}.'.format(*s[1:], len(this_train_dataset)))
                    break

        train_dataloader = this_train_dataset.make_dataloader(args.batch_size, shuffle=True, drop_last=True, nr_workers=args.data_workers)

        for enum_id in range(args.enums_per_epoch):
            train_epoch(epoch, trainer, train_dataloader, meters)

        if epoch % args.validation_interval == 0:
            model.eval()
            validate_epoch(epoch, trainer, validation_dataloader, meters)

        if not args.debug:
            meters.dump(args.meter_file)

        # Hide stale validation meters on non-validation epochs.
        logger.critical(meters.format_simple(
            'Epoch = {}'.format(epoch),
            {k: v for k, v in meters.avg.items() if epoch % args.validation_interval == 0 or not k.startswith('validation')},
            compressed=False
        ))

        if epoch % args.save_interval == 0 and not args.debug:
            fname = osp.join(args.ckpt_dir, 'epoch_{}.pth'.format(epoch))
            trainer.save_checkpoint(fname, dict(epoch=epoch, meta_file=args.meta_file))

        # Decay the learning rate after 60% of the epochs.
        if epoch > int(args.epochs * 0.6):
            trainer.set_learning_rate(args.lr * 0.1)
Exemplo n.º 18
0
    return trainer.early_stopped, trainer.test()


if __name__ == '__main__':
    stats = []
    nr_graduated = 0

    for i in range(args.runs):
        graduated, test_meters = main(i)
        logger.info('run {}'.format(i + 1))

        if test_meters is not None:
            for j, meters in enumerate(test_meters):
                if len(stats) <= j:
                    stats.append(GroupMeters())
                stats[j].update(number=meters.avg['number'],
                                test_acc=meters.avg['accuracy'])

            for meters in stats:
                logger.info('number {}, test_acc {}'.format(
                    meters.avg['number'], meters.avg['test_acc']))

        if not args.test_only:
            nr_graduated += int(graduated)
            logger.info('graduate_ratio {}'.format(nr_graduated / (i + 1)))
            if graduated:
                for j, meters in enumerate(test_meters):
                    stats[j].update(grad_test_acc=meters.avg['accuracy'])
            if nr_graduated > 0:
                for meters in stats:
Exemplo n.º 19
0
def main():
    """Entry point for training (or, with ``--evaluate``, one-shot validation).

    Reads configuration from the module-level ``args`` / ``configs`` objects.
    Side effects: creates dump/checkpoint/visualization directories, redirects
    logging to a file, writes metainfo/meter/checkpoint files, and optionally
    logs to tensorboard and MLDash (all skipped in ``--debug`` mode).
    """
    # directories
    if not args.debug:
        args.dump_dir = ensure_path(osp.join('dumps', args.series_name, args.desc_name, args.run_name))
        args.ckpt_dir = ensure_path(osp.join(args.dump_dir, 'checkpoints'))
        args.vis_dir = ensure_path(osp.join(args.dump_dir, 'visualizations'))
        args.meta_file = osp.join(args.dump_dir, 'metainfo.json')
        args.log_file = osp.join(args.dump_dir, 'log.log')
        args.meter_file = osp.join(args.dump_dir, 'meter.json')

        # Initialize the tensorboard.
        if args.use_tb:
            args.tb_dir = ensure_path(osp.join(args.dump_dir, 'tensorboard'))
        else:
            args.tb_dir = None

    if not args.debug:
        logger.critical('Writing logs to file: "{}".'.format(args.log_file))
        set_output_file(args.log_file)

        # NOTE(review): metainfo is written again later (after meter setup);
        # the two writes look redundant — confirm which one is intended.
        logger.critical('Writing metainfo to file: "{}".'.format(args.meta_file))
        with open(args.meta_file, 'w') as f:
            f.write(dump_metainfo(args=args.__dict__, configs=configs))

    # Tensorboard makes no sense in debug or pure-evaluation runs; force it off.
    if args.debug and args.use_tb:
        logger.warning('Disabling the tensorboard in the debug mode.')
        args.use_tb = False
    if args.evaluate and args.use_tb:
        logger.warning('Disabling the tensorboard in the evaluation mode.')
        args.use_tb = False

    # TODO(Jiayuan Mao @ 04/23): load the dataset.
    logger.critical('Loading the dataset.')
    # NOTE(review): placeholders left as None per the TODO above — the
    # make_dataloader() calls below will fail until real datasets are assigned.
    train_dataset = None
    validation_dataset = None
    # configs.validate_dataset_compatibility(train_dataset)

    # TODO(Jiayuan Mao @ 04/23): build the model.
    logger.critical('Building the model.')
    model = desc.make_model(args)

    if args.use_gpu:
        model.cuda()
        # Use the customized data parallel if applicable.
        if args.gpu_parallel:
            from jactorch.parallel import JacDataParallel
            # Set user_scattered because we will add a multi GPU wrapper to the dataloader. See below.
            model = JacDataParallel(model, device_ids=args.gpus, user_scattered=True).cuda()
        # TODO(Jiayuan Mao @ 04/23): disable the cudnn benchmark.
        # Disable the cudnn benchmark.
        cudnn.benchmark = False

    # Prefer a description-provided optimizer; otherwise default to AdamW over
    # the trainable (requires_grad) parameters only.
    if hasattr(desc, 'make_optimizer'):
        logger.critical('Building customized optimizer.')
        optimizer = desc.make_optimizer(model, args.lr)
    else:
        from jactorch.optim import AdamW
        # TODO(Jiayuan Mao @ 04/23): set the default optimizer.
        trainable_parameters = filter(lambda x: x.requires_grad, model.parameters())
        optimizer = AdamW(trainable_parameters, args.lr, weight_decay=configs.train.weight_decay)

    # Gradient accumulation: wrap the optimizer so updates are applied every
    # acc_grad iterations, emulating a larger effective batch size.
    if args.acc_grad > 1:
        from jactorch.optim import AccumGrad
        optimizer = AccumGrad(optimizer, args.acc_grad)
        logger.warning('Use accumulated grad={:d}, effective iterations per epoch={:d}.'.format(args.acc_grad, int(args.iters_per_epoch / args.acc_grad)))

    trainer = TrainerEnv(model, optimizer)

    # --resume restores full training state (and start epoch); --load only
    # initializes the weights from a pretrained model.
    if args.resume:
        extra = trainer.load_checkpoint(args.resume)
        if extra:
            args.start_epoch = extra['epoch']
            logger.critical('Resume from epoch {}.'.format(args.start_epoch))
    elif args.load:
        if trainer.load_weights(args.load):
            logger.critical('Loaded weights from pretrained model: "{}".'.format(args.load))

    # Meters: tensorboard-backed when enabled, plain GroupMeters otherwise.
    if args.use_tb:
        from jactorch.train.tb import TBLogger, TBGroupMeters
        tb_logger = TBLogger(args.tb_dir)
        meters = TBGroupMeters(tb_logger)
        logger.critical('Writing tensorboard logs to: "{}".'.format(args.tb_dir))
    else:
        from jacinle.utils.meter import GroupMeters
        meters = GroupMeters()

    if not args.debug:
        logger.critical('Writing metainfo to file: "{}".'.format(args.meta_file))
        with open(args.meta_file, 'w') as f:
            f.write(dump_metainfo(args=args.__dict__, configs=configs))
        logger.critical('Writing meter logs to file: "{}".'.format(args.meter_file))

        logger.critical('Initializing MLDash.')
        mldash.init(
            desc_name=args.series_name + '/' + args.desc_name,
            expr_name=args.expr,
            run_name=args.run_name,
            args=args,
            highlight_args=parser,
            configs=configs,
        )
        mldash.update(metainfo_file=args.meta_file, log_file=args.log_file, meter_file=args.meter_file, tb_dir=args.tb_dir)

    # Drop into an interactive shell before training if requested.
    if args.embed:
        from IPython import embed; embed()

    if hasattr(desc, 'customize_trainer'):
        desc.customize_trainer(trainer)

    # TODO(Jiayuan Mao @ 04/23): make the data loader.
    logger.critical('Building the data loader.')
    train_dataloader = train_dataset.make_dataloader(args.batch_size, shuffle=True, drop_last=True, nr_workers=args.data_workers)
    validation_dataloader = validation_dataset.make_dataloader(args.batch_size, shuffle=False, drop_last=False, nr_workers=args.data_workers)

    if args.use_gpu and args.gpu_parallel:
        from jactorch.data.dataloader import JacDataLoaderMultiGPUWrapper
        train_dataloader = JacDataLoaderMultiGPUWrapper(train_dataloader, args.gpus)
        validation_dataloader = JacDataLoaderMultiGPUWrapper(validation_dataloader, args.gpus)

    # Evaluation-only mode: run one validation pass, dump meters, and return.
    if args.evaluate:
        epoch = 0

        model.eval()
        validate_epoch(epoch, trainer, validation_dataloader, meters)

        if not args.debug:
            meters.dump(args.meter_file)

        logger.critical(meters.format_simple('Epoch = {}'.format(epoch), compressed=False))
        return

    # Main training loop: one train pass per epoch, periodic validation,
    # meter dumps, MLDash metric logging, and periodic checkpointing.
    for epoch in range(args.start_epoch + 1, args.epochs + 1):
        meters.reset()

        model.train()
        train_epoch(epoch, trainer, train_dataloader, meters)

        if args.validation_interval > 0 and epoch % args.validation_interval == 0:
            model.eval()
            # no_grad: validation needs no autograd graph.
            with torch.no_grad():
                validate_epoch(epoch, trainer, validation_dataloader, meters)

        if not args.debug:
            meters.dump(args.meter_file)

        # TODO(Jiayuan Mao @ 02/15): config the MLDash.
        if not args.debug:
            mldash.log_metric('epoch', epoch, desc=False, expr=False)
            # Track the best values seen so far: minima for losses, maxima for accuracies.
            for key, value in meters.items():
                if key.startswith('loss') or key.startswith('validation/loss'):
                    mldash.log_metric_min(key, value.avg)
            for key, value in meters.items():
                if key.startswith('acc') or key.startswith('validation/acc'):
                    mldash.log_metric_max(key, value.avg)

        logger.critical(meters.format_simple('Epoch = {}'.format(epoch), compressed=False))

        if not args.debug:
            if epoch % args.save_interval == 0:
                fname = osp.join(args.ckpt_dir, 'epoch_{}.pth'.format(epoch))
                trainer.save_checkpoint(fname, dict(epoch=epoch, meta_file=args.meta_file))
Exemplo n.º 20
0
def main():
    """Entry point for a one-shot evaluation run.

    Loads pretrained weights into the model built from ``desc``, runs a single
    validation epoch, and dumps/logs the resulting meters. In ``--debug`` mode
    no files are written and tensorboard is disabled.
    """
    # directories
    if not args.debug:
        args.dump_dir = ensure_path(
            osp.join('dumps', args.series_name, args.desc_name))
        args.ckpt_dir = ensure_path(osp.join(args.dump_dir, 'checkpoints'))
        args.meta_dir = ensure_path(osp.join(args.dump_dir, 'meta'))
        args.meta_file = osp.join(args.meta_dir, args.run_name + '.json')
        args.log_file = osp.join(args.meta_dir, args.run_name + '.log')
        args.meter_file = osp.join(args.meta_dir,
                                   args.run_name + '.meter.json')

    if not args.debug:
        logger.critical('Writing logs to file: "{}".'.format(args.log_file))
        set_output_file(args.log_file)

        logger.critical('Writing metainfo to file: "{}".'.format(
            args.meta_file))
        with open(args.meta_file, 'w') as f:
            f.write(dump_metainfo(args=args.__dict__, configs=configs))
    else:
        if args.use_tb:
            # NOTE(review): the trailing .format(args.meta_file) is a no-op —
            # the message has no placeholder; likely a copy-paste leftover.
            logger.warning(
                'Disabling the tensorboard in the debug mode.'.format(
                    args.meta_file))
            args.use_tb = False

    # TODO(Jiayuan Mao @ 04/23): load the dataset.
    logger.critical('Loading the dataset.')
    # NOTE(review): placeholder left as None per the TODO above — the
    # make_dataloader() call below will fail until a real dataset is assigned.
    validation_dataset = None
    # configs.validate_dataset_compatibility(train_dataset)

    # TODO(Jiayuan Mao @ 04/23): build the model.
    logger.critical('Building the model.')
    model = desc.make_model(args)

    if args.use_gpu:
        model.cuda()
        # Use the customized data parallel if applicable.
        if args.gpu_parallel:
            from jactorch.parallel import JacDataParallel
            # from jactorch.parallel import UserScatteredJacDataParallel as JacDataParallel
            model = JacDataParallel(model, device_ids=args.gpus).cuda()
        # TODO(Jiayuan Mao @ 04/23): disable the cudnn benchmark.
        # Disable the cudnn benchmark.
        cudnn.benchmark = False

    if load_weights(model, args.load):
        logger.critical('Loaded weights from pretrained model: "{}".'.format(
            args.load))

    # Meters: tensorboard-backed when enabled, plain GroupMeters otherwise.
    # NOTE(review): args.tb_dir is not set anywhere in this function — looks
    # like it must come from the CLI defaults; confirm before enabling --use_tb.
    if args.use_tb:
        from jactorch.train.tb import TBLogger, TBGroupMeters
        tb_logger = TBLogger(args.tb_dir)
        meters = TBGroupMeters(tb_logger)
        logger.critical('Writing tensorboard logs to: "{}".'.format(
            args.tb_dir))
    else:
        from jacinle.utils.meter import GroupMeters
        meters = GroupMeters()

    if not args.debug:
        logger.critical('Writing meter logs to file: "{}".'.format(
            args.meter_file))

    # Drop into an interactive shell before evaluating if requested.
    if args.embed:
        from IPython import embed
        embed()

    # TODO(Jiayuan Mao @ 04/23): make the data loader.
    logger.critical('Building the data loader.')
    validation_dataloader = validation_dataset.make_dataloader(
        args.batch_size,
        shuffle=False,
        drop_last=False,
        nr_workers=args.data_workers)

    # Single evaluation pass; meters collect the metrics.
    model.eval()
    validate_epoch(model, validation_dataloader, meters)

    if not args.debug:
        meters.dump(args.meter_file)

    logger.critical(meters.format_simple('Test', compressed=False))
Exemplo n.º 21
0
def main():
    """Entry point for training the caption-completion model.

    Builds a CompletionModel on top of pretrained word embeddings, trains it
    with AdamW (optionally with gradient accumulation), and checkpoints every
    ``args.save_interval`` epochs. Metainfo is always written to
    ``args.meta_file``.
    """
    logger.critical('Writing metainfo to file: {}.'.format(args.meta_file))

    with open(args.meta_file, 'w') as f:
        f.write(dump_metainfo(args=args.__dict__))

    logger.critical('Loading the word embedding.')
    vocab, word_embeddings = load_word_embedding(args.vse)

    logger.critical('Building up the model.')
    model = CompletionModel(word_embeddings)
    if args.use_gpu:
        model.cuda()
        # Multi-GPU is not supported by this script.
        assert not args.gpu_parallel
    # Disable the cudnn benchmark.
    cudnn.benchmark = False

    logger.critical('Loading the dataset.')

    train_dataset = CompletionDataset(vocab, pjoin(args.data_dir, args.train_img), pjoin(args.data_dir, args.train_cap))

    logger.critical('Building up the data loader.')
    train_loader = make_dataloader(train_dataset, num_workers=args.data_workers, batch_size=args.batch_size, shuffle=True, drop_last=True, pin_memory=True)

    # Default optimizer.
    # Only parameters with requires_grad participate; frozen embeddings stay fixed.
    optimizer = AdamW(filter(lambda x: x.requires_grad, model.parameters()), args.lr, weight_decay=args.weight_decay)
    if args.acc_grad > 1:
        # Gradient accumulation emulates a larger effective batch size.
        optimizer = AccumGrad(optimizer, args.acc_grad)
        logger.warning('Use accumulated grad={:d}, effective iterations per epoch={:d}.'.format(args.acc_grad, int(args.iters_per_epoch / args.acc_grad)))
    trainer = TrainerEnv(model, optimizer)

    # --resume restores full training state (and start epoch); --load only
    # initializes the weights from a pretrained model.
    if args.resume:
        extra = trainer.load_checkpoint(args.resume)
        if extra:
            args.start_epoch = extra['epoch']
            logger.critical('Resume from epoch {}.'.format(args.start_epoch))
    elif args.load:
        if trainer.load_weights(args.load):
            logger.critical('Loaded weights from pretrained model: {}.'.format(args.load))

    # Meters: tensorboard-backed when enabled, plain GroupMeters otherwise.
    if args.use_tb:
        logger.critical('Writing tensorboard logs to: {}.'.format(args.tb_dir))
        tb_logger = TBLogger(args.tb_dir)
        meters = TBGroupMeters(tb_logger)
    else:
        meters = GroupMeters()

    # Switch to train mode.
    model.train()

    # Drop into an interactive shell instead of training if requested.
    if args.embed:
        from IPython import embed; embed()
        return

    # Main training loop: train one epoch, log averaged meters, checkpoint
    # every save_interval epochs. No validation pass in this script.
    for epoch in range(args.start_epoch + 1, args.epochs + 1):
        train_epoch(epoch, train_loader, trainer, meters)

        logger.critical(format_meters('Epoch = {}'.format(epoch), meters.avg, '\t{} = {:.4f}', '\n'))

        if epoch % args.save_interval == 0:
            fname = osp.join(args.ckpt_dir, 'epoch_{}.pth'.format(epoch))
            trainer.save_checkpoint(fname, dict(epoch=epoch, meta_file=args.meta_file))
Exemplo n.º 22
0
def main_train(train_dataset, validation_dataset, extra_dataset=None):
    logger.critical('Building the model.')
    model = desc.make_model(args)
    if args.version=='v3':
        desc_pred = load_source(args.pred_model_path)
        model.build_temporal_prediction_model(args, desc_pred)
    elif args.version=='v4':
        desc_pred = load_source(args.pred_model_path)
        desc_spatial_pred = load_source(args.pred_spatial_model_path)
        model.build_temporal_prediction_model(args, desc_pred, desc_spatial_pred)

    elif args.version=='v2_1':
        model.make_relation_embedding_for_unseen_events(args) 

    if args.use_gpu:
        model.cuda()
        # Disable the cudnn benchmark.
        cudnn.benchmark = False

    if hasattr(desc, 'make_optimizer'):
        logger.critical('Building customized optimizer.')
        optimizer = desc.make_optimizer(model, args.lr)
    else:
        from jactorch.optim import AdamW
        if args.freeze_learner_flag==1:
            if args.reconstruct_flag:
                parameters = list(model._model_pred.parameters())+list(model._decoder.parameters())
                trainable_parameters = filter(lambda x: x.requires_grad, parameters)
            elif args.version=='v4':
                trainable_parameters = filter(lambda x: x.requires_grad, model._model_pred.parameters())
        else:
            trainable_parameters = filter(lambda x: x.requires_grad, model.parameters())
        optimizer = AdamW(trainable_parameters, args.lr, weight_decay=configs.train.weight_decay)

    if args.acc_grad > 1:
        from jactorch.optim import AccumGrad
        optimizer = AccumGrad(optimizer, args.acc_grad)
        logger.warning('Use accumulated grad={:d}, effective iterations per epoch={:d}.'.format(args.acc_grad, int(args.iters_per_epoch / args.acc_grad)))

    trainer = TrainerEnv(model, optimizer)

    if args.resume:
        extra = trainer.load_checkpoint(args.resume)
        if extra:
            args.start_epoch = extra['epoch']
            logger.critical('Resume from epoch {}.'.format(args.start_epoch))
    elif args.load:
        if trainer.load_weights(args.load):
            logger.critical('Loaded weights from pretrained model: "{}".'.format(args.load))
        if args.version=='v3':
            if args.pretrain_pred_model_path:
                model._model_pred.load_state_dict(torch.load(args.pretrain_pred_model_path))
                logger.critical('Loaded weights from pretrained temporal model: "{}".'.format(args.pretrain_pred_model_path))
        elif args.version=='v4':
            if args.pretrain_pred_spatial_model_path:
                model._model_spatial_pred.load_state_dict(torch.load(args.pretrain_pred_spatial_model_path))
                logger.critical('Loaded spatial models from pretrained temporal model: "{}".'.format(args.pretrain_pred_spatial_model_path))
            if args.pretrain_pred_feature_model_path:
                model._model_pred.load_state_dict(torch.load(args.pretrain_pred_feature_model_path))
                logger.critical('Loaded feature models from pretrained temporal model: "{}".'.format(args.pretrain_pred_feature_model_path))
                #pdb.set_trace()
            if args.pretrain_pred_model_path:
                model._model_pred.load_state_dict(torch.load(args.pretrain_pred_model_path))
                logger.critical('Loaded weights from pretrained temporal model: "{}".'.format(args.pretrain_pred_model_path))
        elif args.version =='v2_1':
            model.reasoning.embedding_relation_future.load_state_dict(model.reasoning.embedding_relation.state_dict())
            model.reasoning.embedding_relation_counterfact.load_state_dict(model.reasoning.embedding_relation.state_dict())
            logger.critical('Copy original relation weights into counterfact and future relation.')
    if args.use_tb and not args.debug:
        from jactorch.train.tb import TBLogger, TBGroupMeters
        tb_logger = TBLogger(args.tb_dir)
        meters = TBGroupMeters(tb_logger)
        logger.critical('Writing tensorboard logs to: "{}".'.format(args.tb_dir))
    else:
        from jacinle.utils.meter import GroupMeters
        meters = GroupMeters()

    if not args.debug:
        logger.critical('Writing meter logs to file: "{}".'.format(args.meter_file))

    if args.clip_grad:
        logger.info('Registering the clip_grad hook: {}.'.format(args.clip_grad))
        def clip_grad(self, loss):
            from torch.nn.utils import clip_grad_norm_
            clip_grad_norm_(self.model.parameters(), max_norm=args.clip_grad)
        trainer.register_event('backward:after', clip_grad)

    if hasattr(desc, 'customize_trainer'):
        desc.customize_trainer(trainer)

    if args.embed:
        from IPython import embed; embed()

    if args.debug:
        shuffle_flag=False
    else:
        shuffle_flag=True

    logger.critical('Building the data loader.')
    validation_dataloader = validation_dataset.make_dataloader(args.batch_size, shuffle=False, drop_last=False, nr_workers=args.data_workers)
    if extra_dataset is not None:
        extra_dataloader = extra_dataset.make_dataloader(args.batch_size, shuffle=False, drop_last=False, nr_workers=args.data_workers)

    if args.evaluate:
        meters.reset()
        model.eval()
        validate_epoch(0, trainer, validation_dataloader, meters)
        if extra_dataset is not None:
            validate_epoch(0, trainer, extra_dataloader, meters, meter_prefix='validation_extra')
        logger.critical(meters.format_simple('Validation', {k: v for k, v in meters.avg.items() if v != 0}, compressed=False))
        return meters


    for epoch in range(args.start_epoch + 1, args.epochs + 1):
        meters.reset()

        model.train()

        this_train_dataset = train_dataset
        train_dataloader = this_train_dataset.make_dataloader(args.batch_size, shuffle=shuffle_flag, drop_last=True, nr_workers=args.data_workers)

        for enum_id in range(args.enums_per_epoch):
            train_epoch(epoch, trainer, train_dataloader, meters)

        if epoch % args.validation_interval == 0:
            model.eval()
            validate_epoch(epoch, trainer, validation_dataloader, meters)

        if not args.debug:
            meters.dump(args.meter_file)

        logger.critical(meters.format_simple(
            'Epoch = {}'.format(epoch),
            {k: v for k, v in meters.avg.items() if epoch % args.validation_interval == 0 or not k.startswith('validation')},
            compressed=False
        ))

        if epoch % args.save_interval == 0 and not args.debug:
            fname = osp.join(args.ckpt_dir, 'epoch_{}.pth'.format(epoch))
            trainer.save_checkpoint(fname, dict(epoch=epoch, meta_file=args.meta_file))

        if epoch > int(args.epochs * 0.6):
            trainer.set_learning_rate(args.lr * 0.1)