Example #1
    def test_type_specific_tree_lstm_cell(self):
        device = get_device()
        fix_seed()

        for i, (x_size, h_size, number_of_children) in enumerate(
                zip(self.x_sizes, self.h_sizes, self.numbers_of_children)):
            with self.subTest(i=i):
                g = _gen_node_with_children(number_of_children)
                g.ndata['x'] = torch.rand(number_of_children + 1, x_size)
                g.ndata['type_id'] = torch.tensor(range(0, number_of_children + 1))
                type_relationship = {
                    (0,): [list(range(1, number_of_children // 2))]
                }

                tree_lstm_cell = TypeSpecificTreeLSTMCell(x_size, h_size, type_relationship)

                h_tree_lstm, c_tree_lstm = tree_lstm_cell(g, device)

                tree_lstm_cell_params = tree_lstm_cell.get_params()
                u_f_indices = [
                    tree_lstm_cell.edge_matrix_id.get((0, i), 0) for i in range(1, number_of_children + 1)
                ]
                tree_lstm_cell_params['u_f'] = tree_lstm_cell_params['u_f'][u_f_indices]
                h_calculated, c_calculated = _calculate_nary_tree_lstm_states(g.ndata['x'], **tree_lstm_cell_params)

                self.assertTrue(
                    torch.allclose(h_tree_lstm, h_calculated, atol=ATOL), msg="Unequal hidden state tensors"
                )
                self.assertTrue(
                    torch.allclose(c_tree_lstm, c_calculated, atol=ATOL), msg="Unequal memory state tensors"
                )
Example #2
    def test_childsum_tree_lstm_batch(self):
        device = get_device()
        fix_seed()

        x_size = 5
        h_size = 5
        numbers_of_children = [7, 7]

        tree_lstm_types = [EdgeChildSumTreeLSTMCell, NodeChildSumTreeLSTMCell]
        for tree_lstm_type in tree_lstm_types:
            with self.subTest(msg=f"test {tree_lstm_type.__name__} tree lstm cell"):
                tree_lstm_cell = tree_lstm_type(x_size, h_size)

                g1 = _gen_node_with_children(numbers_of_children[0])
                g2 = _gen_node_with_children(numbers_of_children[1])
                g1.ndata['x'] = torch.rand(numbers_of_children[0] + 1, x_size)
                g2.ndata['x'] = torch.rand(numbers_of_children[1] + 1, x_size)
                g = dgl.batch([g1, g2])

                h_tree_lstm, c_tree_lstm = tree_lstm_cell(g, device)

                h1_calculated, c1_calculated = _calculate_childsum_tree_lstm_states(
                    g1.ndata['x'], **tree_lstm_cell.get_params()
                )
                h2_calculated, c2_calculated = _calculate_childsum_tree_lstm_states(
                    g2.ndata['x'], **tree_lstm_cell.get_params()
                )
                h_calculated = torch.cat([h1_calculated, h2_calculated], 0)
                c_calculated = torch.cat([c1_calculated, c2_calculated], 0)

                self.assertTrue(torch.allclose(h_tree_lstm, h_calculated, atol=ATOL), msg="Unequal hidden state tensors")
                self.assertTrue(torch.allclose(c_tree_lstm, c_calculated, atol=ATOL), msg="Unequal memory state tensors")
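The batched result is checked against a per-graph reference computation. A minimal sketch of such a reference, assuming the standard Child-Sum TreeLSTM equations (Tai et al., 2015), hypothetical parameter names and shapes, and node 0 being the root; the repository's `_calculate_childsum_tree_lstm_states` may differ in both:

import torch


# Hypothetical reference: one root (row 0 of x) with N leaf children.
# Assumed shapes: w_iou (3h, x), u_iou (3h, h), b_iou (3h,), w_f (h, x), u_f (h, h), b_f (h,).
def childsum_reference(x, w_iou, u_iou, b_iou, w_f, u_f, b_f):
    # leaves have no children, so all child-sum terms are zero
    i, o, u = (x[1:] @ w_iou.t() + b_iou).chunk(3, dim=1)
    c_leaves = torch.sigmoid(i) * torch.tanh(u)
    h_leaves = torch.sigmoid(o) * torch.tanh(c_leaves)
    # root: sum of children's hidden states plus a per-child forget gate
    h_sum = h_leaves.sum(0, keepdim=True)
    i, o, u = (x[:1] @ w_iou.t() + h_sum @ u_iou.t() + b_iou).chunk(3, dim=1)
    f = torch.sigmoid(x[:1] @ w_f.t() + h_leaves @ u_f.t() + b_f)
    c_root = torch.sigmoid(i) * torch.tanh(u) + (f * c_leaves).sum(0, keepdim=True)
    h_root = torch.sigmoid(o) * torch.tanh(c_root)
    return torch.cat([h_root, h_leaves]), torch.cat([c_root, c_leaves])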
Example #3
    def test_subtoken_embedding(self):
        fix_seed()
        device = get_device()
        h_emb = 5
        token_to_id = {
            'token|name|first': 0,
            'token|second': 1,
            'token|third|name': 2
        }
        g = dgl.DGLGraph()
        g.add_nodes(3, {'token_id': torch.tensor([0, 1, 2])})
        subtoken_embedding = SubTokenEmbedding(token_to_id, {}, h_emb)

        embed_weight = torch.zeros(len(subtoken_embedding.subtoken_to_id),
                                   h_emb)
        embed_weight[subtoken_embedding.subtoken_to_id['token'], 0] = 1
        embed_weight[subtoken_embedding.subtoken_to_id['name'], 1] = 1
        embed_weight[subtoken_embedding.subtoken_to_id['first'], 2] = 1
        embed_weight[subtoken_embedding.subtoken_to_id['second'], 3] = 1
        embed_weight[subtoken_embedding.subtoken_to_id['third'], 4] = 1

        subtoken_embedding.subtoken_embedding.weight = torch.nn.Parameter(
            embed_weight, requires_grad=True)

        embed_g = subtoken_embedding(g, device)
        true_embeds = torch.tensor(
            [[1, 1, 1, 0, 0], [1, 0, 0, 1, 0], [1, 1, 0, 0, 1]],
            device=device,
            dtype=torch.float)

        self.assertTrue(
            torch.allclose(true_embeds, embed_g.ndata['token_embeds']))
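The expected tensor follows from the one-hot weights set above: each node's embedding is the sum of its sub-token rows, e.g. 'token|name|first' contributes ones in dimensions 0, 1 and 2. A self-contained illustration with a hypothetical sub-token vocabulary (the module's own `subtoken_to_id` ordering may differ):

import torch

embed_weight = torch.eye(5)  # stand-in for the one-hot weights assigned above
subtoken_to_id = {'token': 0, 'name': 1, 'first': 2, 'second': 3, 'third': 4}
rows = [subtoken_to_id[s] for s in 'token|name|first'.split('|')]
print(embed_weight[rows].sum(0))  # tensor([1., 1., 1., 0., 0.])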
Example #4
def interactive(path_to_function: str, path_to_model: str):
    fix_seed()
    device = get_device()
    print(f"using {device} device")

    # convert function to dot format
    print(f"prepare ast...")
    create_folder(TMP_FOLDER)
    if not build_ast(path_to_function):
        return
    ast_folder = os.path.join(TMP_FOLDER, 'java', 'asts')
    ast = os.listdir(ast_folder)
    if len(ast) == 0:
        print("didn't find any functions in the given file")
        return
    if len(ast) > 1:
        print(
            "too many functions in the given file; interactive prediction needs exactly one"
        )
        return
    dgl_ast = convert_dot_to_dgl(os.path.join(ast_folder, ast[0]))
    ast_desc = pd.read_csv(os.path.join(TMP_FOLDER, 'java', 'description.csv'))
    ast_desc['token'].fillna('NAN', inplace=True)
    with open(vocab_path, 'rb') as pkl_file:
        vocab = pkl_load(pkl_file)
        token_to_id, type_to_id = vocab['token_to_id'], vocab['type_to_id']
    ast_desc = transform_keys(ast_desc, token_to_id, type_to_id)
    batched_graph, labels, paths = prepare_batch(ast_desc, ['ast_0.dot'],
                                                 lambda: [dgl_ast])
    batched_graph = dgl.batch(
        list(
            map(lambda g: dgl.reverse(g, share_ndata=True),
                dgl.unbatch(batched_graph))))

    # load model
    print("loading model..")
    model, _ = load_model(path_to_model, device)
    criterion = nn.CrossEntropyLoss(
        ignore_index=model.decoder.pad_index).to(device)
    info = LearningInfo()

    print("forward pass...")
    batch_info, prediction = eval_on_batch(model, criterion, batched_graph,
                                           labels, device)

    info.accumulate_info(batch_info)
    id_to_sublabel = {v: k for k, v in model.decoder.label_to_id.items()}
    label = ''
    for cur_sublabel in prediction:
        if cur_sublabel.item() == model.decoder.label_to_id[EOS]:
            break
        label += '|' + id_to_sublabel[cur_sublabel.item()]
    label = label[1:]
    print(f"Predicted function name is\n{label}")
    print(
        f"Calculated metrics with respect to '{labels[0]}' name\n{info.get_state_dict()}"
    )
Example #5
    @classmethod
    def load_from_model(cls, config, ConfigClass, ModelClass):
        gpu_ids = list(map(int, config.gpu_ids.split()))
        multi_gpu = (len(gpu_ids) > 1)
        device = get_device(gpu_ids)

        ctrl = cls(config)
        ctrl.device = device
        ctrl.trainer = Trainer(ConfigClass, ModelClass, multi_gpu, device,
                               config.print_step, config.output_model_dir,
                               config.fp16)

        return ctrl
Example #6
    def _test_childsum_tree_lstm_cell(self, tree_lstm_type):
        device = get_device()
        fix_seed()
        for i in range(len(self.x_sizes)):
            x_size, h_size, number_of_children = self.x_sizes[i], self.h_sizes[i], self.numbers_of_children[i]
            with self.subTest(i=i):
                h_equal, c_equal = _test_childsum(
                    tree_lstm_type, x_size, h_size, number_of_children, device
                )
                self.assertTrue(
                    h_equal, msg=f"Unequal hidden state tensors for ({x_size}, {h_size}, {number_of_children}) params"
                )
                self.assertTrue(
                    c_equal, msg=f"Unequal memory state tensors for ({x_size}, {h_size}, {number_of_children}) params"
                )
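The comparison itself is delegated to a helper that is not shown. A minimal sketch of what `_test_childsum` might do, assuming it lives next to the helpers used in the earlier examples (`_gen_node_with_children`, `_calculate_childsum_tree_lstm_states`, `ATOL`); this is an assumption, not the repository's code:

import torch


def _test_childsum(tree_lstm_type, x_size, h_size, number_of_children, device):
    # build a one-level tree, run the cell, and compare against the hand-computed states
    g = _gen_node_with_children(number_of_children)
    g.ndata['x'] = torch.rand(number_of_children + 1, x_size)
    cell = tree_lstm_type(x_size, h_size)
    h_cell, c_cell = cell(g, device)
    h_ref, c_ref = _calculate_childsum_tree_lstm_states(g.ndata['x'], **cell.get_params())
    return torch.allclose(h_cell, h_ref, atol=ATOL), torch.allclose(c_cell, c_ref, atol=ATOL)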
Example #7
def interactive(path_to_function: str, path_to_model: str):
    fix_seed()
    device = get_device()
    print(f"using {device} device")

    # load model
    print("loading model...")
    checkpoint = torch.load(path_to_model, map_location=device)

    model = Tree2Seq(**checkpoint['configuration']).to(device)
    model.load_state_dict(checkpoint['state_dict'])

    token_to_id = model.token_to_id
    type_to_id = model.type_to_id
    label_to_id = model.label_to_id
    id_to_label = {v: k for k, v in label_to_id.items()}

    # convert function to dgl format
    print("convert function to dgl format...")
    create_folder(TMP_FOLDER)
    build_asts(path_to_function, TMP_FOLDER, ASTMINER_PATH, *ASTMINER_PARAMS)
    project_folder = os.path.join(TMP_FOLDER, 'java')
    convert_project(project_folder, token_to_id, type_to_id, label_to_id, True,
                    True, 5, 6, False, True, '|')

    # load function
    graph, labels = load_graphs(os.path.join(project_folder, 'converted.dgl'))
    labels = labels['labels']
    assert len(labels) == 1, f"found {len(labels)} functions instead of 1"
    ast = graph[0].reverse(share_ndata=True)
    ast.ndata['token'] = ast.ndata['token'].to(device)
    ast.ndata['type'] = ast.ndata['type'].to(device)
    labels = labels.t().to(device)
    root_indexes = torch.tensor([0], dtype=torch.long)

    # forward pass
    model.eval()
    with torch.no_grad():
        logits = model(ast, root_indexes, labels, device)
    logits = logits[1:]
    prediction = model.predict(logits).reshape(-1)
    sublabels = [id_to_label[label_id.item()] for label_id in prediction]
    label = '|'.join(takewhile(lambda sl: sl != EOS, sublabels))
    print(f"the predicted label is:\n{label}")
Example #8
def evaluate(params: Dict) -> None:
    fix_seed()
    device = get_device()
    print(f"using {device} device")

    evaluation_set = JavaDataset(params['paths']['evaluate'],
                                 params['batch_size'], True)

    model, _ = load_model(params['paths']['model'], device)

    # define loss function
    criterion = nn.CrossEntropyLoss(
        ignore_index=model.decoder.pad_index).to(device)

    # evaluation loop
    print("ok, let's evaluate it")
    eval_epoch_info = evaluate_dataset(evaluation_set, model, criterion,
                                       device)

    print(eval_epoch_info.get_state_dict())
Example #9
def val():
    """Validation."""
    torch.backends.cudnn.benchmark = True

    # model
    model, model_wrapper = mc.get_model()
    ema = mc.setup_ema(model)
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    # TODO(meijieru): cal loss on all GPUs instead only `cuda:0` when non
    # distributed

    # check pretrained
    if FLAGS.pretrained:
        checkpoint = torch.load(FLAGS.pretrained,
                                map_location=lambda storage, loc: storage)
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        model_wrapper.load_state_dict(checkpoint['model'])
        logging.info('Loaded model {}.'.format(FLAGS.pretrained))

    if udist.is_master():
        logging.info(model_wrapper)

    # data
    (train_transforms, val_transforms, test_transforms) = \
        dataflow.data_transforms(FLAGS)
    (train_set, val_set, test_set) = dataflow.dataset(train_transforms,
                                                      val_transforms,
                                                      test_transforms, FLAGS)
    _, calib_loader, _, test_loader = dataflow.data_loader(
        train_set, val_set, test_set, FLAGS)

    if udist.is_master():
        logging.info('Start testing.')
    FLAGS._global_step = 0
    test_meters = mc.get_meters('test')
    validate(0, calib_loader, test_loader, criterion, test_meters,
             model_wrapper, ema, 'test')
    return
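Here `get_device` receives a module (`get_device(model)`) rather than being called with no arguments as in the earlier examples. A plausible sketch of that overload, offered as an assumption rather than the project's actual helper:

import torch


def get_device(module: torch.nn.Module) -> torch.device:
    # the device of a module is taken to be the device of its first parameter
    return next(module.parameters()).device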
Example #10
def evaluate(params: Dict) -> None:
    fix_seed()
    device = get_device()
    print(f"using {device} device")

    checkpoint = torch.load(params['model'], map_location=device)

    print('model initializing...')
    # create model
    model = Tree2Seq(**checkpoint['configuration']).to(device)
    model.load_state_dict(checkpoint['state_dict'])

    evaluation_set = TreeDGLDataset(params['dataset'], params['batch_size'], device, True)

    # define loss function
    criterion = nn.CrossEntropyLoss(ignore_index=model.label_to_id[PAD]).to(device)

    # evaluation loop
    print("ok, let's evaluate it")
    eval_epoch_info = evaluate_on_dataset(evaluation_set, model, criterion)

    print(eval_epoch_info.get_state_dict())
Example #11
    def init(self, ModelClass):
        '''
        ModelClass: e.g. modelTC
        '''
        gpu_ids = list(map(int, self.config.gpu_ids.split()))
        multi_gpu = (len(gpu_ids) > 1)
        self.device = get_device(gpu_ids)

        if self.config.mission == 'train':
            model_dir = self.config.PTM_model_vocab_dir
        else:
            model_dir = self.config.model_save_dir
        print('init_model', model_dir)
        model = ModelClass.from_pretrained(model_dir)
        print(model)

        if multi_gpu:
            model = torch.nn.DataParallel(model, device_ids=gpu_ids)

        self.trainer = Trainer(model, multi_gpu, self.device,
                               self.config.print_step,
                               self.config.model_save_dir, self.config.fp16)
        self.model = model
Example #12
def train(params: Dict, logging: str) -> None:
    fix_seed()
    device = get_device()
    print(f"using {device} device")

    training_set = JavaDataset(params['paths']['train'], params['batch_size'],
                               True)
    validation_set = JavaDataset(params['paths']['validate'],
                                 params['batch_size'], True)

    with open(params['paths']['vocabulary'], 'rb') as pkl_file:
        vocabulary = pkl_load(pkl_file)
        token_to_id = vocabulary['token_to_id']
        type_to_id = vocabulary['type_to_id']
        label_to_id = vocabulary['label_to_id']

    print('model initializing...')
    is_resumed = 'resume' in params
    if is_resumed:
        # load model
        model, checkpoint = load_model(params['resume'], device)
        start_batch_id = checkpoint['batch_id'] + 1
        configuration = checkpoint['configuration']
    else:
        # create model
        model_factory = ModelFactory(params['embedding'], params['encoder'],
                                     params['decoder'],
                                     params['hidden_states'], token_to_id,
                                     type_to_id, label_to_id)
        model: Tree2Seq = model_factory.construct_model(device)
        configuration = model_factory.save_configuration()
        start_batch_id = 0

    # create optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=params['lr'],
                                 weight_decay=params['weight_decay'])
    # create scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=params['scheduler_step_size'],
        gamma=params['scheduler_gamma'])

    # define loss function
    criterion = nn.CrossEntropyLoss(
        ignore_index=model.decoder.pad_index).to(device)

    # init logging class
    logger = None
    if logging == TerminalLogger.name:
        logger = TerminalLogger(params['checkpoints_folder'])
    elif logging == FileLogger.name:
        logger = FileLogger(params, params['logging_folder'],
                            params['checkpoints_folder'])
    elif logging == WandBLogger.name:
        logger_args = ['treeLSTM', params, model, params['checkpoints_folder']]
        if 'resume_wandb_id' in params:
            logger_args.append(params['resume_wandb_id'])
        logger = WandBLogger(*logger_args)

    # train loop
    print("ok, let's train it")
    for epoch in range(params['n_epochs']):
        train_acc_info = LearningInfo()

        if epoch > 0:
            # the resumed start batch id applies only to the first epoch
            start_batch_id = 0
        tqdm_batch_iterator = tqdm(range(start_batch_id, len(training_set)),
                                   total=len(training_set))
        tqdm_batch_iterator.update(start_batch_id)
        tqdm_batch_iterator.refresh()

        # iterate over training set
        for batch_id in tqdm_batch_iterator:
            graph, labels = training_set[batch_id]
            graph.ndata['token_id'] = graph.ndata['token_id'].to(device)
            graph.ndata['type_id'] = graph.ndata['type_id'].to(device)
            batch_info = train_on_batch(model, criterion, optimizer, scheduler,
                                        graph, labels, params, device)
            train_acc_info.accumulate_info(batch_info)
            # log current train process
            if is_current_step_match(batch_id, params['logging_step']):
                logger.log(train_acc_info.get_state_dict(), epoch, batch_id)
                train_acc_info = LearningInfo()
            # validate current model
            if is_current_step_match(
                    batch_id, params['evaluation_step']) and batch_id != 0:
                eval_epoch_info = evaluate_dataset(validation_set, model,
                                                   criterion, device)
                logger.log(eval_epoch_info.get_state_dict(), epoch, batch_id,
                           False)
            # save current model
            if is_current_step_match(
                    batch_id, params['checkpoint_step']) and batch_id != 0:
                logger.save_model(model,
                                  f'epoch_{epoch}_batch_{batch_id}.pt',
                                  configuration,
                                  batch_id=batch_id)

        logger.log(train_acc_info.get_state_dict(), epoch, len(training_set))
        eval_epoch_info = evaluate_dataset(validation_set, model, criterion,
                                           device)
        logger.log(eval_epoch_info.get_state_dict(), epoch, len(training_set),
                   False)

        logger.save_model(model, f'epoch_{epoch}.pt', configuration)
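Logging, evaluation and checkpointing are all gated by `is_current_step_match`. A plausible sketch of that predicate (an assumption, not the repository's code):

def is_current_step_match(current_step: int, period: int) -> bool:
    # a step "matches" when it is a non-trivial multiple of the configured period
    return period != 0 and current_step % period == 0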
Example #13
    def compress_by_mask(self, masks, **kwargs):
        """Regenerate internal compute graph given alive masks."""
        device = get_device(self.pw_bn)
        cu.compress_inverted_residual_channels(self, masks, **kwargs)
        self.to(device)
Example #14
def train_val_test():
    """Train and val."""
    torch.backends.cudnn.benchmark = True

    # model
    model, model_wrapper = mc.get_model()
    ema = mc.setup_ema(model)
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    criterion_smooth = optim.CrossEntropyLabelSmooth(
        FLAGS.model_kwparams['num_classes'],
        FLAGS['label_smoothing'],
        reduction='none').cuda()
    # TODO(meijieru): cal loss on all GPUs instead only `cuda:0` when non
    # distributed

    if FLAGS.get('log_graph_only', False):
        if udist.is_master():
            _input = torch.zeros(1, 3, FLAGS.image_size,
                                 FLAGS.image_size).cuda()
            _input = _input.requires_grad_(True)
            mc.summary_writer.add_graph(model_wrapper, (_input, ),
                                        verbose=True)
        return

    # check pretrained
    if FLAGS.pretrained:
        checkpoint = torch.load(FLAGS.pretrained,
                                map_location=lambda storage, loc: storage)
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        # update keys from external models
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if (hasattr(FLAGS, 'pretrained_model_remap_keys')
                and FLAGS.pretrained_model_remap_keys):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                logging.info('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        logging.info('Loaded model {}.'.format(FLAGS.pretrained))
    optimizer = optim.get_optimizer(model_wrapper, FLAGS)

    # check resume training
    if FLAGS.resume:
        checkpoint = torch.load(os.path.join(FLAGS.resume,
                                             'latest_checkpoint.pt'),
                                map_location=lambda storage, loc: storage)
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        last_epoch = checkpoint['last_epoch']
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        lr_scheduler.last_epoch = (last_epoch + 1) * FLAGS._steps_per_epoch
        best_val = extract_item(checkpoint['best_val'])
        train_meters, val_meters = checkpoint['meters']
        FLAGS._global_step = (last_epoch + 1) * FLAGS._steps_per_epoch
        if udist.is_master():
            logging.info('Loaded checkpoint {} at epoch {}.'.format(
                FLAGS.resume, last_epoch))
    else:
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        # last_epoch = lr_scheduler.last_epoch
        last_epoch = -1
        best_val = 1.
        train_meters = mc.get_meters('train')
        val_meters = mc.get_meters('val')
        FLAGS._global_step = 0

    if not FLAGS.resume and udist.is_master():
        logging.info(model_wrapper)
    if FLAGS.profiling:
        if 'gpu' in FLAGS.profiling:
            mc.profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            mc.profiling(model, use_cuda=False)

    # data
    (train_transforms, val_transforms,
     test_transforms) = dataflow.data_transforms(FLAGS)
    (train_set, val_set, test_set) = dataflow.dataset(train_transforms,
                                                      val_transforms,
                                                      test_transforms, FLAGS)
    (train_loader, calib_loader, val_loader,
     test_loader) = dataflow.data_loader(train_set, val_set, test_set, FLAGS)

    if FLAGS.test_only and (test_loader is not None):
        if udist.is_master():
            logging.info('Start testing.')
        test_meters = mc.get_meters('test')
        validate(last_epoch, calib_loader, test_loader, criterion, test_meters,
                 model_wrapper, ema, 'test')
        return

    # already broadcast by AllReduceDistributedDataParallel
    # optimizer load same checkpoint/same initialization

    if udist.is_master():
        logging.info('Start training.')

    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        # train
        results = run_one_epoch(epoch,
                                train_loader,
                                model_wrapper,
                                criterion_smooth,
                                optimizer,
                                lr_scheduler,
                                ema,
                                train_meters,
                                phase='train')

        # val
        results = validate(epoch, calib_loader, val_loader, criterion,
                           val_meters, model_wrapper, ema, 'val')
        if results['top1_error'] < best_val:
            best_val = results['top1_error']

            if udist.is_master():
                save_status(model_wrapper, optimizer, ema, epoch, best_val,
                            (train_meters, val_meters),
                            os.path.join(FLAGS.log_dir, 'best_model.pt'))
                logging.info(
                    'New best validation top1 error: {:.4f}'.format(best_val))
        if udist.is_master():
            # save latest checkpoint
            save_status(model_wrapper, optimizer, ema, epoch, best_val,
                        (train_meters, val_meters),
                        os.path.join(FLAGS.log_dir, 'latest_checkpoint.pt'))

        wandb.log(
            {
                "Validation Accuracy": 1. - results['top1_error'],
                "Best Validation Accuracy": 1. - best_val
            },
            step=epoch)


        # NOTE(meijieru): from scheduler code, should be called after train/val
        # use stepwise scheduler instead
        # lr_scheduler.step()
    return
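`best_val` is restored through `extract_item` when resuming. A plausible sketch of that helper (an assumption): unwrap a zero-dimensional tensor into a plain Python number and pass anything else through unchanged.

import torch


def extract_item(value):
    return value.item() if isinstance(value, torch.Tensor) else value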
Example #15
def train_val_test():
    """Train and val."""
    torch.backends.cudnn.benchmark = True  # For acceleration

    # model
    model, model_wrapper = mc.get_model()
    ema = mc.setup_ema(model)
    criterion = torch.nn.CrossEntropyLoss(reduction='mean').cuda()
    criterion_smooth = optim.CrossEntropyLabelSmooth(
        FLAGS.model_kwparams['num_classes'],
        FLAGS['label_smoothing'],
        reduction='mean').cuda()
    if model.task == 'segmentation':
        criterion = CrossEntropyLoss().cuda()
        criterion_smooth = CrossEntropyLoss().cuda()
    if FLAGS.dataset == 'coco':
        criterion = JointsMSELoss(use_target_weight=True).cuda()
        criterion_smooth = JointsMSELoss(use_target_weight=True).cuda()

    if FLAGS.get('log_graph_only', False):
        if udist.is_master():
            _input = torch.zeros(1, 3, FLAGS.image_size,
                                 FLAGS.image_size).cuda()
            _input = _input.requires_grad_(True)
            if isinstance(model_wrapper,
                          (torch.nn.DataParallel,
                           udist.AllReduceDistributedDataParallel)):
                mc.summary_writer.add_graph(model_wrapper.module, (_input, ),
                                            verbose=True)
            else:
                mc.summary_writer.add_graph(model_wrapper, (_input, ),
                                            verbose=True)
        return

    # check pretrained
    if FLAGS.pretrained:
        checkpoint = torch.load(FLAGS.pretrained,
                                map_location=lambda storage, loc: storage)
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        # update keys from external models
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if (hasattr(FLAGS, 'pretrained_model_remap_keys')
                and FLAGS.pretrained_model_remap_keys):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                if udist.is_master():
                    logging.info('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        if udist.is_master():
            logging.info('Loaded model {}.'.format(FLAGS.pretrained))
    optimizer = optim.get_optimizer(model_wrapper, FLAGS)

    # check resume training
    if FLAGS.resume:
        checkpoint = torch.load(os.path.join(FLAGS.resume,
                                             'latest_checkpoint.pt'),
                                map_location=lambda storage, loc: storage)
        model_wrapper = checkpoint['model'].cuda()
        model = model_wrapper.module
        # model = checkpoint['model'].module
        optimizer = checkpoint['optimizer']
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        # model_wrapper.load_state_dict(checkpoint['model'])
        # optimizer.load_state_dict(checkpoint['optimizer'])
        if ema:
            # ema.load_state_dict(checkpoint['ema'])
            ema = checkpoint['ema'].cuda()
            ema.to(get_device(model))
        last_epoch = checkpoint['last_epoch']
        lr_scheduler = optim.get_lr_scheduler(optimizer,
                                              FLAGS,
                                              last_epoch=(last_epoch + 1) *
                                              FLAGS._steps_per_epoch)
        lr_scheduler.last_epoch = (last_epoch + 1) * FLAGS._steps_per_epoch
        best_val = extract_item(checkpoint['best_val'])
        train_meters, val_meters = checkpoint['meters']
        FLAGS._global_step = (last_epoch + 1) * FLAGS._steps_per_epoch
        if udist.is_master():
            logging.info('Loaded checkpoint {} at epoch {}.'.format(
                FLAGS.resume, last_epoch))
    else:
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        # last_epoch = lr_scheduler.last_epoch
        last_epoch = -1
        best_val = 1.
        if not FLAGS.distill:
            train_meters = mc.get_meters('train', FLAGS.prune_params['method'])
            val_meters = mc.get_meters('val')
        else:
            train_meters = mc.get_distill_meters('train',
                                                 FLAGS.prune_params['method'])
            val_meters = mc.get_distill_meters('val')
        if FLAGS.model_kwparams.task == 'segmentation':
            best_val = 0.
            if not FLAGS.distill:
                train_meters = mc.get_seg_meters('train',
                                                 FLAGS.prune_params['method'])
                val_meters = mc.get_seg_meters('val')
            else:
                train_meters = mc.get_seg_distill_meters(
                    'train', FLAGS.prune_params['method'])
                val_meters = mc.get_seg_distill_meters('val')
        FLAGS._global_step = 0

    if not FLAGS.resume and udist.is_master():
        logging.info(model_wrapper)
    assert FLAGS.profiling, '`m.macs` is used for calculating penalty'
    # if udist.is_master():
    #     model.apply(lambda m: print(m))
    if FLAGS.profiling:
        if 'gpu' in FLAGS.profiling:
            mc.profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            mc.profiling(model, use_cuda=False)

    if FLAGS.dataset == 'cityscapes':
        (train_set, val_set,
         test_set) = seg_dataflow.cityscapes_datasets(FLAGS)
        segval = SegVal(num_classes=19)
    elif FLAGS.dataset == 'ade20k':
        (train_set, val_set, test_set) = seg_dataflow.ade20k_datasets(FLAGS)
        segval = SegVal(num_classes=150)
    elif FLAGS.dataset == 'coco':
        (train_set, val_set, test_set) = seg_dataflow.coco_datasets(FLAGS)
        # print(len(train_set), len(val_set))  # 149813 104125
        segval = None
    else:
        # data
        (train_transforms, val_transforms,
         test_transforms) = dataflow.data_transforms(FLAGS)
        (train_set, val_set,
         test_set) = dataflow.dataset(train_transforms, val_transforms,
                                      test_transforms, FLAGS)
        segval = None
    (train_loader, calib_loader, val_loader,
     test_loader) = dataflow.data_loader(train_set, val_set, test_set, FLAGS)

    # get bn's weights
    if FLAGS.prune_params.use_transformer:
        FLAGS._bn_to_prune, FLAGS._bn_to_prune_transformer = prune.get_bn_to_prune(
            model, FLAGS.prune_params)
    else:
        FLAGS._bn_to_prune = prune.get_bn_to_prune(model, FLAGS.prune_params)
    rho_scheduler = prune.get_rho_scheduler(FLAGS.prune_params,
                                            FLAGS._steps_per_epoch)

    if FLAGS.test_only and (test_loader is not None):
        if udist.is_master():
            logging.info('Start testing.')
        test_meters = mc.get_meters('test')
        validate(last_epoch, calib_loader, test_loader, criterion, test_meters,
                 model_wrapper, ema, 'test')
        return

    # already broadcast by AllReduceDistributedDataParallel
    # optimizer load same checkpoint/same initialization

    if udist.is_master():
        logging.info('Start training.')

    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        # train
        results = run_one_epoch(epoch,
                                train_loader,
                                model_wrapper,
                                criterion_smooth,
                                optimizer,
                                lr_scheduler,
                                ema,
                                rho_scheduler,
                                train_meters,
                                phase='train')

        if (epoch + 1) % FLAGS.eval_interval == 0:
            # val
            results, model_eval_wrapper = validate(epoch, calib_loader,
                                                   val_loader, criterion,
                                                   val_meters, model_wrapper,
                                                   ema, 'val', segval, val_set)

            if FLAGS.prune_params['method'] is not None and FLAGS.prune_params[
                    'bn_prune_filter'] is not None:
                prune_threshold = FLAGS.model_shrink_threshold  # 1e-3
                masks = prune.cal_mask_network_slimming_by_threshold(
                    get_prune_weights(model_eval_wrapper), prune_threshold
                )  # get mask for all bn weights (depth-wise)
                FLAGS._bn_to_prune.add_info_list('mask', masks)
                flops_pruned, infos = prune.cal_pruned_flops(
                    FLAGS._bn_to_prune)
                log_pruned_info(mc.unwrap_model(model_eval_wrapper),
                                flops_pruned, infos, prune_threshold)
                if not FLAGS.distill:
                    if flops_pruned >= FLAGS.model_shrink_delta_flops \
                            or epoch == FLAGS.num_epochs - 1:
                        ema_only = (epoch == FLAGS.num_epochs - 1)
                        shrink_model(model_wrapper, ema, optimizer,
                                     FLAGS._bn_to_prune, prune_threshold,
                                     ema_only)
            model_kwparams = mb.output_network(mc.unwrap_model(model_wrapper))

            if udist.is_master():
                if FLAGS.model_kwparams.task == 'classification' and results[
                        'top1_error'] < best_val:
                    best_val = results['top1_error']
                    logging.info(
                        'New best validation top1 error: {:.4f}'.format(
                            best_val))

                    save_status(model_wrapper, model_kwparams, optimizer, ema,
                                epoch, best_val, (train_meters, val_meters),
                                os.path.join(FLAGS.log_dir, 'best_model'))

                elif FLAGS.model_kwparams.task == 'segmentation' and FLAGS.dataset != 'coco' and results[
                        'mIoU'] > best_val:
                    best_val = results['mIoU']
                    logging.info('New seg mIoU: {:.4f}'.format(best_val))

                    save_status(model_wrapper, model_kwparams, optimizer, ema,
                                epoch, best_val, (train_meters, val_meters),
                                os.path.join(FLAGS.log_dir, 'best_model'))
                elif FLAGS.dataset == 'coco' and results > best_val:
                    best_val = results
                    logging.info('New Result: {:.4f}'.format(best_val))
                    save_status(model_wrapper, model_kwparams, optimizer, ema,
                                epoch, best_val, (train_meters, val_meters),
                                os.path.join(FLAGS.log_dir, 'best_model'))

                # save latest checkpoint
                save_status(model_wrapper, model_kwparams, optimizer, ema,
                            epoch, best_val, (train_meters, val_meters),
                            os.path.join(FLAGS.log_dir, 'latest_checkpoint'))

    return
Example #16
def train_val_test():
    """Train and val."""
    torch.backends.cudnn.benchmark = True

    # model
    model, model_wrapper = get_model()
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    criterion_smooth = optim.CrossEntropyLabelSmooth(
        FLAGS.model_kwparams['num_classes'],
        FLAGS['label_smoothing'],
        reduction='none').cuda()
    # TODO: cal loss on all GPUs instead only `cuda:0` when non
    # distributed

    ema = None
    if FLAGS.moving_average_decay > 0.0:
        if FLAGS.moving_average_decay_adjust:
            moving_average_decay = optim.ExponentialMovingAverage.adjust_momentum(
                FLAGS.moving_average_decay,
                FLAGS.moving_average_decay_base_batch / FLAGS.batch_size)
        else:
            moving_average_decay = FLAGS.moving_average_decay
        logging.info('Moving average for model parameters: {}'.format(
            moving_average_decay))
        ema = optim.ExponentialMovingAverage(moving_average_decay)
        for name, param in model.named_parameters():
            ema.register(name, param)
        # We maintain mva for batch norm moving mean and variance as well.
        for name, buffer in model.named_buffers():
            if 'running_var' in name or 'running_mean' in name:
                ema.register(name, buffer)

    if FLAGS.get('log_graph_only', False):
        if is_root_rank:
            _input = torch.zeros(1, 3, FLAGS.image_size,
                                 FLAGS.image_size).cuda()
            _input = _input.requires_grad_(True)
            summary_writer.add_graph(model_wrapper, (_input, ), verbose=True)
        return

    # check pretrained
    if FLAGS.pretrained:
        checkpoint = torch.load(FLAGS.pretrained,
                                map_location=lambda storage, loc: storage)
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        # update keys from external models
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if (hasattr(FLAGS, 'pretrained_model_remap_keys')
                and FLAGS.pretrained_model_remap_keys):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                logging.info('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        logging.info('Loaded model {}.'.format(FLAGS.pretrained))
    optimizer = optim.get_optimizer(model_wrapper, FLAGS)

    # check resume training
    if FLAGS.resume:
        checkpoint = torch.load(os.path.join(FLAGS.resume,
                                             'latest_checkpoint.pt'),
                                map_location=lambda storage, loc: storage)
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        last_epoch = checkpoint['last_epoch']
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        lr_scheduler.last_epoch = (last_epoch + 1) * FLAGS._steps_per_epoch
        best_val = extract_item(checkpoint['best_val'])
        train_meters, val_meters = checkpoint['meters']
        FLAGS._global_step = (last_epoch + 1) * FLAGS._steps_per_epoch
        if is_root_rank:
            logging.info('Loaded checkpoint {} at epoch {}.'.format(
                FLAGS.resume, last_epoch))
    else:
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        # last_epoch = lr_scheduler.last_epoch
        last_epoch = -1
        best_val = 1.
        train_meters = get_meters('train')
        val_meters = get_meters('val')
        FLAGS._global_step = 0

    if not FLAGS.resume and is_root_rank:
        logging.info(model_wrapper)
    assert FLAGS.profiling, '`m.macs` is used for calculating penalty'
    if FLAGS.profiling:
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)

    # data
    (train_transforms, val_transforms,
     test_transforms) = dataflow.data_transforms(FLAGS)
    (train_set, val_set, test_set) = dataflow.dataset(train_transforms,
                                                      val_transforms,
                                                      test_transforms, FLAGS)
    (train_loader, calib_loader, val_loader,
     test_loader) = dataflow.data_loader(train_set, val_set, test_set, FLAGS)

    # get bn's weights
    FLAGS._bn_to_prune = prune.get_bn_to_prune(model, FLAGS.prune_params)
    rho_scheduler = prune.get_rho_scheduler(FLAGS.prune_params,
                                            FLAGS._steps_per_epoch)

    if FLAGS.test_only and (test_loader is not None):
        if is_root_rank:
            logging.info('Start testing.')
        test_meters = get_meters('test')
        validate(last_epoch, calib_loader, test_loader, criterion, test_meters,
                 model_wrapper, ema, 'test')
        return

    # already broadcast by AllReduceDistributedDataParallel
    # optimizer load same checkpoint/same initialization

    if is_root_rank:
        logging.info('Start training.')

    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        # train
        results = run_one_epoch(epoch,
                                train_loader,
                                model_wrapper,
                                criterion_smooth,
                                optimizer,
                                lr_scheduler,
                                ema,
                                rho_scheduler,
                                train_meters,
                                phase='train')

        # val
        results, model_eval_wrapper = validate(epoch, calib_loader, val_loader,
                                               criterion, val_meters,
                                               model_wrapper, ema, 'val')

        if FLAGS.prune_params['method'] is not None:
            prune_threshold = FLAGS.model_shrink_threshold
            masks = prune.cal_mask_network_slimming_by_threshold(
                get_prune_weights(model_eval_wrapper), prune_threshold)
            FLAGS._bn_to_prune.add_info_list('mask', masks)
            flops_pruned, infos = prune.cal_pruned_flops(FLAGS._bn_to_prune)
            log_pruned_info(unwrap_model(model_eval_wrapper), flops_pruned,
                            infos, prune_threshold)
            if flops_pruned >= FLAGS.model_shrink_delta_flops \
                    or epoch == FLAGS.num_epochs - 1:
                ema_only = (epoch == FLAGS.num_epochs - 1)
                shrink_model(model_wrapper, ema, optimizer, FLAGS._bn_to_prune,
                             prune_threshold, ema_only)
        model_kwparams = mb.output_network(unwrap_model(model_wrapper))

        if results['top1_error'] < best_val:
            best_val = results['top1_error']

            if is_root_rank:
                save_status(model_wrapper, model_kwparams, optimizer, ema,
                            epoch, best_val, (train_meters, val_meters),
                            os.path.join(FLAGS.log_dir, 'best_model'))
                logging.info(
                    'New best validation top1 error: {:.4f}'.format(best_val))

        if is_root_rank:
            # save latest checkpoint
            save_status(model_wrapper, model_kwparams, optimizer, ema, epoch,
                        best_val, (train_meters, val_meters),
                        os.path.join(FLAGS.log_dir, 'latest_checkpoint'))

        # NOTE: from scheduler code, should be called after train/val
        # use stepwise scheduler instead
        # lr_scheduler.step()
    return
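Pruning in the last two examples relies on threshold-based masks over batch-norm scales. A hedged sketch of that idea in the network-slimming style (an assumption, not the repository's `cal_mask_network_slimming_by_threshold`):

import torch


def masks_by_threshold(bn_weights, threshold):
    # a channel survives when the absolute value of its batch-norm scale exceeds the threshold
    return [w.detach().abs() > threshold for w in bn_weights]

# e.g. masks = masks_by_threshold([bn.weight for bn in prunable_bn_layers], 1e-3)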