Example #1
def finetune_first_image(model, images, targets, optimizer, scheduler, logger, cfg):
    total_iter_finetune = cfg.FINETUNE.TOTAL_ITER
    model.train()
    meters = MetricLogger(delimiter="  ")
    for iteration in range(total_iter_finetune):

        scheduler.step()
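        # the scheduler is stepped once per iteration before the forward pass; note that
        # PyTorch >= 1.1 recommends calling scheduler.step() after optimizer.step()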
        loss_dict, _ = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(total_loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        meters.update(lr=optimizer.param_groups[0]["lr"])

        if iteration % (total_iter_finetune // 2) == 0:
            logger.info(
                meters.delimiter.join(
                    [
                        "{meters}",
                    ]
                ).format(
                    meters=str(meters),
                )
            )

    model.eval()
    return model
Example #2
def trainIters(args):
    """ main: loop over different epoch. and datasplit """
    epoch_resume = args.epoch_resume
    model_dir = os.path.join(args.models_root, args.model_name)
    board_dir = os.path.join(args.models_root, 'boards', args.model_name)
    if args.local_rank == 0:
        make_dir(board_dir)
        make_dir(model_dir)
    start = time.time()
    meters = {
        args.train_split: MetricLogger(delimiter=" "),
        args.eval_split: MetricLogger(delimiter=" ")
    }
    args.model_dir = model_dir
    enc_opt, dec_opt, trainer = build_model(args)
    max_eval_iter = args.max_eval_iter
    # save parameters for future use
    if args.local_rank == 0:
        tb_writer = SummaryWriter(board_dir)
        pickle.dump(args,
                    open(os.path.join(model_dir, timestr + '_args.pkl'), 'wb'))
        pickle.dump(args, open(os.path.join(model_dir, 'args.pkl'),
                               'wb'))  # overwrite the latest args
        logging.info('save args in %s' %
                     os.path.join(model_dir, timestr + '_args.pkl'))
        logging.info('{}'.format(args))
    start = time.time()
    # vars for early stopping
    best_val_loss = args.best_val_loss
    best_val_epo = 0
    acc_patience = 0
    mt_val = -1

    # keep track of the number of batches in each epoch for continuity when plotting curves
    if args.local_rank == 0: logging.info('init_dataloaders')
    start = time.time()
    loaders = init_dataloaders(args)
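    # init_dataloaders is expected to return a dict of loaders keyed by split name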
    num_batches = {args.train_split: 0, args.eval_split: 0}

    if args.local_rank == 0:
        logging.info('dataloader %.3f' % (time.time() - start))

    for e in range(args.max_epoch):
        # check if it's time to do some changes here
        if e + epoch_resume >= args.finetune_after and not args.sample_inference_mask and args.finetune_after != -1:
            args.sample_inference_mask = 1
            logging.info('=' * 10 + '> start sample_inference_mask')
            acc_patience, best_val_loss = 0, 0
        # in current epoch, loop over split
        # we validate after each epoch
        if max_eval_iter > 0 and e == 0:
            splits = [args.eval_split, args.train_split, args.eval_split]
        elif max_eval_iter == 0:
            splits = [args.train_split]
        else:
            splits = [args.train_split, args.eval_split]
        for split in splits:
            if split == args.eval_split:
                trainer.eval()
            # loop over batches in current epoch
            if args.local_rank == 0:
                logging.info('epoch %d - %s; ' % (e + epoch_resume, split))
                logging.info(
                    '-- loss weight loss_weight_match: {} loss_weight_iouraw {}; '
                    .format(args.loss_weight_match, args.loss_weight_iouraw))
                sd = time.time()
                start = time.time()
            iter_time = []
            for batch_idx, (inputs, imgs_names, targets, seq_name,
                            starting_frame) in enumerate(loaders[split]):
                # imgs_names: can be proposals: List[tuple(BoxList)], len of list=Nframe, len-of-tuple=BatchSize
                if args.local_rank == 0:
                    start_iter = time.time()
                    dataT = time.time() - sd
                assert isinstance(targets, list)
                inputs = [sub.to(args.device) for sub in inputs]
                targets = [sub.to(args.device) for sub in targets]
                if args.load_proposals_dataset:
                    proposals_cur_batch = imgs_names  # len=framelen
                    proposals = []  # BoxList of current batch
                    for p in proposals_cur_batch:
                        boxlist = list(p)  # BoxList of current batch
                        proposals.append([b.to(args.device)
                                          for b in boxlist])  # len=BatchSize
                    imgs_names = None
                else:
                    proposals = None

                # forward
                if split == args.eval_split:
                    with torch.no_grad():
                        loss, losses = trainer(batch_idx, inputs, imgs_names,
                                               targets, seq_name,
                                               starting_frame, split, args,
                                               proposals)
                else:
                    loss, losses = trainer(batch_idx, inputs, imgs_names,
                                           targets, seq_name, starting_frame,
                                           split, args, proposals)
                ## import pdb; pdb.set_trace()
                #if DEBUG:
                #    logging.info('>> profile ')
                #    logging.info('seq_name {}, inputs sum {}; proposals: {} imgs_names {}'.format(seq_name, inputs[0].sum(), proposals[0][0].bbox.sum(), imgs_names))
                #    info = {'batch_idx': batch_idx, 'info':[seq_name, inputs[0].shape, inputs[0].sum(), proposals[0][0].bbox.sum(), imgs_names, losses, loss]}
                #    check_info = torch.load('../../drvos/src/debug/%d.pth'%batch_idx)
                #    CHECKDEBUG(info, check_info)

                loss = loss.mean()  #reduce_loss_dict({'loss':loss})

                if split == args.train_split:  # and args.local_rank == 0:
                    dec_opt.zero_grad()
                    enc_opt.zero_grad()
                    if loss.requires_grad:
                        loss.backward()
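                        # in distributed runs, gradients are averaged across ranks by hand before stepping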
                        if args.distributed:
                            average_gradients(trainer, args.local_rank)
                            torch.cuda.synchronize()
                        dec_opt.step()
                        enc_opt.step()
                # record the losses
                # store loss values in dictionary separately
                if args.distributed:
                    losses = reduce_loss_dict(losses)

                if args.ngpus > 1 and args.local_rank == 0:
                    for k, v in losses.items():
                        if not args.distributed:
                            losses[k] = v.mean()
                        tb_writer.add_scalar(
                            '%s/%s' % (k, split), losses[k], batch_idx +
                            (e + epoch_resume) * len(loaders[split]))
                elif args.local_rank == 0:
                    for k, v in losses.items():
                        tb_writer.add_scalar(
                            '%s/%s' % (k, split), v, batch_idx +
                            (e + epoch_resume) * len(loaders[split]))
                if args.local_rank == 0: meters[split].update(**losses)
                # print after some iterations
                if (batch_idx + 1) % args.print_every == 0 and args.local_rank == 0:  # iteration
                    te = time.time() - start_iter
                    iter_time.append(te)
                    remain_t = (
                        sum(iter_time) / len(iter_time) *
                        (len(loaders[split]) - batch_idx)) / 60.0 / 60.0
                    max_mem = "mem: {memory:.0f}".format(
                        memory=torch.cuda.max_memory_allocated() / 1024.0 /
                        1024.0)
                    meters[split].update(time=te, dt=dataT)
                    logging.info("%s:%s:p%d(%d-%.2f):E%d it%d/%d: rt(%.2fh) %s|%s"%(args.model_name, split, acc_patience, \
                            best_val_epo, best_val_loss, (e+epoch_resume), batch_idx, len(loaders[split]), remain_t, \
                            str(meters[split]), max_mem))
                    start = time.time()
                if args.local_rank == 0 and split == args.train_split  and (((batch_idx + 1) % args.save_every == 0) \
                        or batch_idx + 1 == len(loaders[split])):
                    logging.info('save model at {} {}'.format(
                        batch_idx, e + epoch_resume))
                    save_checkpoint_iter(
                        trainer, args,
                        'epo%02d_iter%05d' % (e + epoch_resume, batch_idx),
                        enc_opt, dec_opt)
                sd = time.time()
            # out of for-all-batches in current split
            num_batches[split] = batch_idx + 1
            # for loss_name in ['loss', 'match_loss', 'iou', 'iouraw', 'hard_iou_raw']:
            if split == args.eval_split:
                for loss_name in ['hard_iou_raw', 'hard_iou']:  # prefer hard_iou over hard_iou_raw
                    if loss_name in meters[
                            args.eval_split].fields() and max_eval_iter != 0:
                        mt_val = meters[args.eval_split].load_field(
                            loss_name).global_avg
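                # reset the eval meters so the next validation pass starts from a clean average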
                meters[args.eval_split] = MetricLogger(delimiter=" ")
                if mt_val > (best_val_loss + args.min_delta):
                    logging.info("Saving checkpoint.")
                    best_val_loss = mt_val
                    best_val_epo = e + epoch_resume
                    # saves model, params, and optimizers
                    save_checkpoint_iter(
                        trainer, args, 'best_%.3f_epo%02d' %
                        (best_val_loss, e + epoch_resume), enc_opt, dec_opt)
                    acc_patience = 0
                else:
                    acc_patience += 1
        if acc_patience > args.patience_stop:
            logging.info('acc_patience reached its maximum; stopping early.')
            break
Example #3
def train(cfg, local_rank, distributed, logger):
    if is_main_process():
        wandb.init(project='scene-graph',
                   entity='sgg-speaker-listener',
                   config=cfg.LISTENER)
    debug_print(logger, 'prepare training')

    model = build_detection_model(cfg)
    listener = build_listener(cfg)

    speaker_listener = SpeakerListener(model,
                                       listener,
                                       cfg,
                                       is_joint=cfg.LISTENER.JOINT)
    if is_main_process():
        wandb.watch(listener)

    debug_print(logger, 'end model construction')

    # modules that should be always set in eval mode
    # their eval() method should be called after model.train() is called
    eval_modules = (
        model.rpn,
        model.backbone,
        model.roi_heads.box,
    )

    fix_eval_modules(eval_modules)

    # NOTE, we slow down the LR of the layers start with the names in slow_heads
    if cfg.MODEL.ROI_RELATION_HEAD.PREDICTOR == "IMPPredictor":
        slow_heads = [
            "roi_heads.relation.box_feature_extractor",
            "roi_heads.relation.union_feature_extractor.feature_extractor",
        ]
    else:
        slow_heads = []

    # load pretrain layers to new layers
    load_mapping = {
        "roi_heads.relation.box_feature_extractor":
        "roi_heads.box.feature_extractor",
        "roi_heads.relation.union_feature_extractor.feature_extractor":
        "roi_heads.box.feature_extractor"
    }

    if cfg.MODEL.ATTRIBUTE_ON:
        load_mapping[
            "roi_heads.relation.att_feature_extractor"] = "roi_heads.attribute.feature_extractor"
        load_mapping[
            "roi_heads.relation.union_feature_extractor.att_feature_extractor"] = "roi_heads.attribute.feature_extractor"

    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    listener.to(device)

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    num_batch = cfg.SOLVER.IMS_PER_BATCH

    optimizer = make_optimizer(cfg,
                               model,
                               logger,
                               slow_heads=slow_heads,
                               slow_ratio=10.0,
                               rl_factor=float(num_batch))
    listener_optimizer = make_listener_optimizer(cfg, listener)
    scheduler = make_lr_scheduler(cfg, optimizer, logger)
    listener_scheduler = None
    debug_print(logger, 'end optimizer and schedule')

    if cfg.LISTENER.JOINT:
        speaker_listener_optimizer = make_speaker_listener_optimizer(
            cfg, speaker_listener.speaker, speaker_listener.listener)

    # Initialize mixed-precision training
    use_mixed_precision = cfg.DTYPE == "float16"
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
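    # NOTE: amp_opt_level is computed here but the amp.initialize calls below hard-code 'O0'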

    if cfg.LISTENER.JOINT:
        speaker_listener, speaker_listener_optimizer = amp.initialize(
            speaker_listener, speaker_listener_optimizer, opt_level='O0')
    else:
        speaker_listener, listener_optimizer = amp.initialize(
            speaker_listener, listener_optimizer, opt_level='O0')

    #listener, listener_optimizer = amp.initialize(listener, listener_optimizer, opt_level='O0')
    #[model, listener], [optimizer, listener_optimizer] = amp.initialize([model, listener], [optimizer, listener_optimizer], opt_level='O1', loss_scale=1)
    #model = amp.initialize(model, opt_level='O1')

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
            find_unused_parameters=True,
        )

        listener = torch.nn.parallel.DistributedDataParallel(
            listener,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
            find_unused_parameters=True,
        )

    debug_print(logger, 'end distributed')
    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR
    listener_dir = cfg.LISTENER_DIR
    save_to_disk = get_rank() == 0

    speaker_checkpointer = DetectronCheckpointer(cfg,
                                                 model,
                                                 optimizer,
                                                 scheduler,
                                                 output_dir,
                                                 save_to_disk,
                                                 custom_scheduler=True)

    listener_checkpointer = Checkpointer(listener,
                                         optimizer=listener_optimizer,
                                         save_dir=listener_dir,
                                         save_to_disk=save_to_disk,
                                         custom_scheduler=False)

    speaker_listener.add_listener_checkpointer(listener_checkpointer)
    speaker_listener.add_speaker_checkpointer(speaker_checkpointer)

    speaker_listener.load_listener()
    speaker_listener.load_speaker(load_mapping=load_mapping)
    debug_print(logger, 'end load checkpointer')
    train_data_loader = make_data_loader(cfg,
                                         mode='train',
                                         is_distributed=distributed,
                                         start_iter=arguments["iteration"],
                                         ret_images=True)
    val_data_loaders = make_data_loader(cfg,
                                        mode='val',
                                        is_distributed=distributed,
                                        ret_images=True)

    debug_print(logger, 'end dataloader')
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if cfg.SOLVER.PRE_VAL:
        logger.info("Validate before training")
        #output =  run_val(cfg, model, listener, val_data_loaders, distributed, logger)
        #print('OUTPUT: ', output)
        #(sg_loss, img_loss, sg_acc, img_acc) = output

    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(train_data_loader)
    start_iter = arguments["iteration"]
    start_training_time = time.time()
    end = time.time()

    print_first_grad = True

    listener_loss_func = torch.nn.MarginRankingLoss(margin=1, reduction='none')
    mistake_saver = None
    if is_main_process():
        ds_catalog = DatasetCatalog()
        dict_file_path = os.path.join(
            ds_catalog.DATA_DIR,
            ds_catalog.DATASETS['VG_stanford_filtered_with_attribute']
            ['dict_file'])
        ind_to_classes, ind_to_predicates = load_vg_info(dict_file_path)
        ind_to_classes = {k: v for k, v in enumerate(ind_to_classes)}
        ind_to_predicates = {k: v for k, v in enumerate(ind_to_predicates)}
        print('ind to classes:', ind_to_classes, '\n ind to predicates:',
              ind_to_predicates)
        mistake_saver = MistakeSaver(
            '/Scene-Graph-Benchmark.pytorch/filenames_masked', ind_to_classes,
            ind_to_predicates)

    #is_printed = False
    while True:
        try:
            listener_iteration = 0
            for iteration, (images, targets,
                            image_ids) in enumerate(train_data_loader,
                                                    start_iter):

                if cfg.LISTENER.JOINT:
                    speaker_listener_optimizer.zero_grad()
                else:
                    listener_optimizer.zero_grad()

                #print(f'ITERATION NUMBER: {iteration}')
                if any(len(target) < 1 for target in targets):
                    logger.error(
                        f"Iteration={iteration + 1} || Image Ids used for training {image_ids} || targets Length={[len(target) for target in targets]}"
                    )
                if len(images) <= 1:
                    continue

                data_time = time.time() - end
                iteration = iteration + 1
                listener_iteration += 1
                arguments["iteration"] = iteration
                model.train()
                fix_eval_modules(eval_modules)
                images_list = deepcopy(images)
                images_list = to_image_list(
                    images_list, cfg.DATALOADER.SIZE_DIVISIBILITY).to(device)

                for i in range(len(images)):
                    images[i] = images[i].unsqueeze(0)
                    images[i] = F.interpolate(images[i],
                                              size=(224, 224),
                                              mode='bilinear',
                                              align_corners=False)
                    images[i] = images[i].squeeze()

                images = torch.stack(images).to(device)
                #images.requires_grad_()

                targets = [target.to(device) for target in targets]

                speaker_loss_dict = {}
                if not cfg.LISTENER.JOINT:
                    score_matrix = speaker_listener(images_list, targets,
                                                    images)
                else:
                    score_matrix, _, speaker_loss_dict = speaker_listener(
                        images_list, targets, images)

                speaker_summed_losses = sum(
                    loss for loss in speaker_loss_dict.values())

                # reduce losses over all GPUs for logging purposes
                if cfg.LISTENER.JOINT:
                    speaker_loss_dict_reduced = reduce_loss_dict(
                        speaker_loss_dict)
                    speaker_losses_reduced = sum(
                        loss for loss in speaker_loss_dict_reduced.values())
                    speaker_losses_reduced /= num_gpus

                    if is_main_process():
                        wandb.log(
                            {"Train Speaker Loss": speaker_losses_reduced},
                            listener_iteration)

                listener_loss = 0
                gap_reward = 0
                avg_acc = 0
                num_correct = 0

                score_matrix = score_matrix.to(device)
                # fill loss matrix
                loss_matrix = torch.zeros((2, images.size(0), images.size(0)),
                                          device=device)
                # sg centered scores
                for true_index in range(loss_matrix.size(1)):
                    row_score = score_matrix[true_index]
                    (true_scores, predicted_scores,
                     binary) = format_scores(row_score, true_index, device)
                    loss_vec = listener_loss_func(true_scores,
                                                  predicted_scores, binary)
                    loss_matrix[0][true_index] = loss_vec
                # image centered scores
                transposed_score_matrix = score_matrix.t()
                for true_index in range(loss_matrix.size(1)):
                    row_score = transposed_score_matrix[true_index]
                    (true_scores, predicted_scores,
                     binary) = format_scores(row_score, true_index, device)
                    loss_vec = listener_loss_func(true_scores,
                                                  predicted_scores, binary)
                    loss_matrix[1][true_index] = loss_vec

                print('iteration:', listener_iteration)
                sg_acc = 0
                img_acc = 0
                # calculate accuracy
                for i in range(loss_matrix.size(1)):
                    temp_sg_acc = 0
                    temp_img_acc = 0
                    for j in range(loss_matrix.size(2)):
                        if loss_matrix[0][i][i] > loss_matrix[0][i][j]:
                            temp_sg_acc += 1
                        else:
                            if cfg.LISTENER.HTML:
                                if is_main_process(
                                ) and listener_iteration >= 600 and listener_iteration % 25 == 0 and i != j:
                                    detached_sg_i = (sgs[i][0].detach(),
                                                     sgs[i][1],
                                                     sgs[i][2].detach())
                                    detached_sg_j = (sgs[j][0].detach(),
                                                     sgs[j][1],
                                                     sgs[j][2].detach())
                                    mistake_saver.add_mistake(
                                        (image_ids[i], image_ids[j]),
                                        (detached_sg_i, detached_sg_j),
                                        listener_iteration, 'SG')
                        if loss_matrix[1][i][i] > loss_matrix[1][j][i]:
                            temp_img_acc += 1
                        else:
                            if cfg.LISTENER.HTML:
                                if is_main_process(
                                ) and listener_iteration >= 600 and listener_iteration % 25 == 0 and i != j:
                                    detached_sg_i = (sgs[i][0].detach(),
                                                     sgs[i][1],
                                                     sgs[i][2].detach())
                                    detached_sg_j = (sgs[j][0].detach(),
                                                     sgs[j][1],
                                                     sgs[j][2].detach())
                                    mistake_saver.add_mistake(
                                        (image_ids[i], image_ids[j]),
                                        (detached_sg_i, detached_sg_j),
                                        listener_iteration, 'IMG')

                    temp_sg_acc = temp_sg_acc * 100 / (loss_matrix.size(1) - 1)
                    temp_img_acc = temp_img_acc * 100 / (loss_matrix.size(1) -
                                                         1)
                    sg_acc += temp_sg_acc
                    img_acc += temp_img_acc
                if cfg.LISTENER.HTML:
                    if is_main_process(
                    ) and listener_iteration % 100 == 0 and listener_iteration >= 600:
                        mistake_saver.toHtml('/www')

                sg_acc /= loss_matrix.size(1)
                img_acc /= loss_matrix.size(1)

                avg_sg_acc = torch.tensor([sg_acc]).to(device)
                avg_img_acc = torch.tensor([img_acc]).to(device)
                # reduce acc over all gpus
                avg_acc = {'sg_acc': avg_sg_acc, 'img_acc': avg_img_acc}
                avg_acc_reduced = reduce_loss_dict(avg_acc)

                sg_acc = sum(acc for acc in avg_acc_reduced['sg_acc'])
                img_acc = sum(acc for acc in avg_acc_reduced['img_acc'])

                # log acc to wadb
                if is_main_process():
                    wandb.log({
                        "Train SG Accuracy": sg_acc.item(),
                        "Train IMG Accuracy": img_acc.item()
                    })

                sg_loss = 0
                img_loss = 0

                for i in range(loss_matrix.size(0)):
                    for j in range(loss_matrix.size(1)):
                        loss_matrix[i][j][j] = 0.
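                # with the diagonal (correct-pair) entries zeroed, the max below picks the hardest negative per row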

                for i in range(loss_matrix.size(1)):
                    sg_loss += torch.max(loss_matrix[0][i])
                    img_loss += torch.max(loss_matrix[1][i])

                sg_loss = sg_loss / loss_matrix.size(1)
                img_loss = img_loss / loss_matrix.size(1)
                sg_loss = sg_loss.to(device)
                img_loss = img_loss.to(device)

                loss_dict = {'sg_loss': sg_loss, 'img_loss': img_loss}

                losses = sum(loss for loss in loss_dict.values())

                # reduce losses over all GPUs for logging purposes
                loss_dict_reduced = reduce_loss_dict(loss_dict)
                sg_loss_reduced = loss_dict_reduced['sg_loss']
                img_loss_reduced = loss_dict_reduced['img_loss']
                if is_main_process():
                    wandb.log({"Train SG Loss": sg_loss_reduced})
                    wandb.log({"Train IMG Loss": img_loss_reduced})

                losses_reduced = sum(loss
                                     for loss in loss_dict_reduced.values())
                meters.update(loss=losses_reduced, **loss_dict_reduced)

                losses = losses + speaker_summed_losses * cfg.LISTENER.LOSS_COEF
                # Note: If mixed precision is not used, this ends up doing nothing
                # Otherwise apply loss scaling for mixed-precision recipe
                #losses.backward()
                if not cfg.LISTENER.JOINT:
                    with amp.scale_loss(losses,
                                        listener_optimizer) as scaled_losses:
                        scaled_losses.backward()
                else:
                    with amp.scale_loss(
                            losses,
                            speaker_listener_optimizer) as scaled_losses:
                        scaled_losses.backward()

                verbose = (iteration % cfg.SOLVER.PRINT_GRAD_FREQ
                           ) == 0 or print_first_grad  # print grad or not
                print_first_grad = False
                #clip_grad_value([(n, p) for n, p in listener.named_parameters() if p.requires_grad], cfg.LISTENER.CLIP_VALUE, logger=logger, verbose=True, clip=True)
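                # `verbose` is only consumed by the clip_grad_value call above, which is currently disabled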
                if not cfg.LISTENER.JOINT:
                    listener_optimizer.step()
                else:
                    speaker_listener_optimizer.step()

                batch_time = time.time() - end
                end = time.time()
                meters.update(time=batch_time, data=data_time)

                eta_seconds = meters.time.global_avg * (max_iter - iteration)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

                if cfg.LISTENER.JOINT:
                    if iteration % 200 == 0 or iteration == max_iter:
                        logger.info(
                            meters.delimiter.join([
                                "eta: {eta}",
                                "iter: {iter}",
                                "{meters}",
                                "lr: {lr:.6f}",
                                "max mem: {memory:.0f}",
                            ]).format(
                                eta=eta_string,
                                iter=iteration,
                                meters=str(meters),
                                lr=speaker_listener_optimizer.param_groups[-1]
                                ["lr"],
                                memory=torch.cuda.max_memory_allocated() /
                                1024.0 / 1024.0,
                            ))
                else:
                    if iteration % 200 == 0 or iteration == max_iter:
                        logger.info(
                            meters.delimiter.join([
                                "eta: {eta}",
                                "iter: {iter}",
                                "{meters}",
                                "lr: {lr:.6f}",
                                "max mem: {memory:.0f}",
                            ]).format(
                                eta=eta_string,
                                iter=iteration,
                                meters=str(meters),
                                lr=listener_optimizer.param_groups[-1]["lr"],
                                memory=torch.cuda.max_memory_allocated() /
                                1024.0 / 1024.0,
                            ))

                if iteration % checkpoint_period == 0:
                    """
                    print('Model before save')
                    print('****************************')
                    print(listener.gnn.conv1.node_model.node_mlp_1[0].weight)
                    print('****************************')
                    """
                    if not cfg.LISTENER.JOINT:
                        listener_checkpointer.save(
                            "model_{:07d}".format(listener_iteration),
                            amp=amp.state_dict())
                    else:
                        speaker_checkpointer.save(
                            "model_speaker{:07d}".format(iteration))
                        listener_checkpointer.save(
                            "model_listenr{:07d}".format(listener_iteration),
                            amp=amp.state_dict())
                if iteration == max_iter:
                    if not cfg.LISTENER.JOINT:
                        listener_checkpointer.save(
                            "model_{:07d}".format(listener_iteration),
                            amp=amp.state_dict())
                    else:
                        speaker_checkpointer.save(
                            "model_{:07d}".format(iteration))
                        listener_checkpointer.save(
                            "model_{:07d}".format(listener_iteration),
                            amp=amp.state_dict())

                val_result = None  # used for scheduler updating
                if cfg.SOLVER.TO_VAL and iteration % cfg.SOLVER.VAL_PERIOD == 0:
                    logger.info("Start validating")
                    val_result = run_val(cfg, model, listener,
                                         val_data_loaders, distributed, logger)
                    (sg_loss, img_loss, sg_acc, img_acc,
                     speaker_val) = val_result

                    if is_main_process():
                        wandb.log({
                            "Validation SG Accuracy": sg_acc,
                            "Validation IMG Accuracy": img_acc,
                            "Validation SG Loss": sg_loss,
                            "Validation IMG Loss": img_loss,
                            "Validation Speaker": speaker_val,
                        })

                    #logger.info("Validation Result: %.4f" % val_result)
        except Exception as err:
            raise err
            # NOTE: the recovery code below is unreachable while the re-raise above is present;
            # drop the raise if the intent is to rebuild the data loader and keep training.
            print('Dataset finished, creating new')
            train_data_loader = make_data_loader(
                cfg,
                mode='train',
                is_distributed=distributed,
                start_iter=arguments["iteration"],
                ret_images=True)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
    return listener
Example #4
def train(cfg, local_rank, distributed, logger):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    optimizer = make_optimizer(cfg,
                               model,
                               logger,
                               rl_factor=float(cfg.SOLVER.IMS_PER_BATCH))
    scheduler = make_lr_scheduler(cfg, optimizer)

    # Initialize mixed-precision training
    use_mixed_precision = cfg.DTYPE == "float16"
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=amp_opt_level)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(
        cfg.MODEL.WEIGHT,
        update_schedule=cfg.SOLVER.UPDATE_SCHEDULE_DURING_LOAD)
    arguments.update(extra_checkpoint_data)
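    # arguments now carries the resumed iteration (and any other state stored in the checkpoint)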

    train_data_loader = make_data_loader(
        cfg,
        mode='train',
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )
    val_data_loaders = make_data_loader(
        cfg,
        mode='val',
        is_distributed=distributed,
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if cfg.SOLVER.PRE_VAL:
        logger.info("Validate before training")
        run_val(cfg, model, val_data_loaders, distributed)

    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(train_data_loader)
    start_iter = arguments["iteration"]
    start_training_time = time.time()
    end = time.time()
    for iteration, (images, targets, _) in enumerate(train_data_loader,
                                                     start_iter):
        model.train()

        if any(len(target) < 1 for target in targets):
            logger.error(
                f"Iteration={iteration + 1} || Image Ids used for training {_} || targets Length={[len(target) for target in targets]}"
            )
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        scheduler.step()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        # Note: If mixed precision is not used, this ends up doing nothing
        # Otherwise apply loss scaling for mixed-precision recipe
        with amp.scale_loss(losses, optimizer) as scaled_losses:
            scaled_losses.backward()
        optimizer.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 200 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))

        if cfg.SOLVER.TO_VAL and iteration % cfg.SOLVER.VAL_PERIOD == 0:
            logger.info("Start validating")
            run_val(cfg, model, val_data_loaders, distributed)

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))

    return model
Example #5
def train(cfg, local_rank, distributed, logger):
    debug_print(logger, 'prepare training')
    model = build_detection_model(cfg) 
    debug_print(logger, 'end model construction')

    # modules that should be always set in eval mode
    # their eval() method should be called after model.train() is called
    eval_modules = (model.rpn, model.backbone, model.roi_heads.box,)
 
    fix_eval_modules(eval_modules)

    # NOTE, we slow down the LR of the layers start with the names in slow_heads
    if cfg.MODEL.ROI_RELATION_HEAD.PREDICTOR == "IMPPredictor":
        slow_heads = ["roi_heads.relation.box_feature_extractor",
                      "roi_heads.relation.union_feature_extractor.feature_extractor",]
    else:
        slow_heads = []

    # load pretrain layers to new layers
    load_mapping = {"roi_heads.relation.box_feature_extractor" : "roi_heads.box.feature_extractor",
                    "roi_heads.relation.union_feature_extractor.feature_extractor" : "roi_heads.box.feature_extractor"}
    
    if cfg.MODEL.ATTRIBUTE_ON:
        load_mapping["roi_heads.relation.att_feature_extractor"] = "roi_heads.attribute.feature_extractor"
        load_mapping["roi_heads.relation.union_feature_extractor.att_feature_extractor"] = "roi_heads.attribute.feature_extractor"

    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    num_batch = cfg.SOLVER.IMS_PER_BATCH
    optimizer = make_optimizer(cfg, model, logger, slow_heads=slow_heads, slow_ratio=10.0, rl_factor=float(num_batch))
    scheduler = make_lr_scheduler(cfg, optimizer, logger)
    debug_print(logger, 'end optimizer and schedule')
    # Initialize mixed-precision training
    use_mixed_precision = cfg.DTYPE == "float16"
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
    model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
            find_unused_parameters=True,
        )
    debug_print(logger, 'end distributed')
    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk, custom_scheduler=True
    )
    # if there is certain checkpoint in output_dir, load it, else load pretrained detector
    if checkpointer.has_checkpoint():
        extra_checkpoint_data = checkpointer.load(cfg.MODEL.PRETRAINED_DETECTOR_CKPT, 
                                       update_schedule=cfg.SOLVER.UPDATE_SCHEDULE_DURING_LOAD)
        arguments.update(extra_checkpoint_data)
        if cfg.SOLVER.UPDATE_SCHEDULE_DURING_LOAD:
            checkpointer.scheduler.last_epoch = extra_checkpoint_data["iteration"]
            logger.info("update last epoch of scheduler to iter: {}".format(str(extra_checkpoint_data["iteration"])))
    else:
        # load_mapping is only used when we init current model from detection model.
        checkpointer.load(cfg.MODEL.PRETRAINED_DETECTOR_CKPT, with_optim=False, load_mapping=load_mapping)
    debug_print(logger, 'end load checkpointer')
    train_data_loader = make_data_loader(
        cfg,
        mode='train',
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )
    val_data_loaders = make_data_loader(
        cfg,
        mode='val',
        is_distributed=distributed,
    )
    debug_print(logger, 'end dataloader')
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if cfg.SOLVER.PRE_VAL:
        logger.info("Validate before training")
        run_val(cfg, model, val_data_loaders, distributed, logger)

    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(train_data_loader)
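    # len(train_data_loader) is used as the total iteration budget for this run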
    start_iter = arguments["iteration"]
    start_training_time = time.time()
    end = time.time()

    print_first_grad = True
    for iteration, (images, targets, _) in enumerate(train_data_loader, start_iter):
        if any(len(target) < 1 for target in targets):
            logger.error(f"Iteration={iteration + 1} || Image Ids used for training {_} || targets Length={[len(target) for target in targets]}" )
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        model.train()
        fix_eval_modules(eval_modules)

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        # Note: If mixed precision is not used, this ends up doing nothing
        # Otherwise apply loss scaling for mixed-precision recipe
        with amp.scale_loss(losses, optimizer) as scaled_losses:
            scaled_losses.backward()
        
        # add clip_grad_norm from MOTIFS, tracking gradient, used for debug
        verbose = (iteration % cfg.SOLVER.PRINT_GRAD_FREQ) == 0 or print_first_grad # print grad or not
        print_first_grad = False
        clip_grad_norm([(n, p) for n, p in model.named_parameters() if p.requires_grad], max_norm=cfg.SOLVER.GRAD_NORM_CLIP, logger=logger, verbose=verbose, clip=True)

        optimizer.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 200 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[-1]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

        val_result = None # used for scheduler updating
        if cfg.SOLVER.TO_VAL and iteration % cfg.SOLVER.VAL_PERIOD == 0:
            logger.info("Start validating")
            val_result = run_val(cfg, model, val_data_loaders, distributed, logger)
            logger.info("Validation Result: %.4f" % val_result)
 
        # scheduler should be called after optimizer.step() in pytorch>=1.1.0
        # https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
        if cfg.SOLVER.SCHEDULE.TYPE == "WarmupReduceLROnPlateau":
            scheduler.step(val_result, epoch=iteration)
            if scheduler.stage_count >= cfg.SOLVER.SCHEDULE.MAX_DECAY_STEP:
                logger.info("Trigger MAX_DECAY_STEP at iteration {}.".format(iteration))
                break
        else:
            scheduler.step()

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / (max_iter)
        )
    )
    return model
Example #6
def do_train(model, data_loader, optimizer, scheduler, checkpointer, device,
             validation_period, checkpoint_period, arguments, run_validation):
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()

    saved_models = {}
    best_metric = float("-inf")
    best_model_iter = None

    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        data_time = time.time() - end
        batch_start = time.time()
        arguments["iteration"] = iteration

        if iteration % validation_period == 0:
            results = validate_and_log(model, run_validation, iteration)

            first_dataset_results = next(iter(results.items()))
            dataset_name = first_dataset_results[0]
            metric_name = "mAP@0.5"  # assumed metric key; replace with the key actually reported by run_validation
            metric = first_dataset_results[1][metric_name]

            if metric > best_metric:
                logger.info(
                    f"Found a new current best model: iter {iteration}, {metric_name} on {dataset_name} = {metric:0.4f}"
                )
                # checkpoint the best model
                best_metric = metric
                best_model_iter = iteration
                model_filename = 'model_best'
                checkpointer.save(model_filename, **arguments)

        if iteration % checkpoint_period == 0:
            model_filename = 'model_{:07d}'.format(iteration)
            checkpointer.save(model_filename, **arguments)
            saved_models[iteration] = model_filename + '.pth'
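            # NOTE: saved_models is populated here but never read later in this snippet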

        model.train()

        scheduler.step()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        batch_time = time.time() - batch_start
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        logger.info(
            meters.delimiter.join([
                "eta: {eta}",
                "iter: {iter}",
                "{meters}",
                "lr: {lr:.6f}",
                "max mem: {memory:.0f}",
            ]).format(
                eta=eta_string,
                iter=iteration,
                meters=str(meters),
                lr=optimizer.param_groups[0]["lr"],
                memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
            ))

        losses_str = meters.delimiter.join(
            ["Loss: {:.4f}".format(losses.item())] + [
                "{0}: {1:.4f}".format(k, v.item())
                for k, v in loss_dict_reduced.items()
            ])
        logger.info(losses_str)

    validate_and_log(model, run_validation, arguments["iteration"])
    checkpointer.save("model_final", **arguments)

    if max_iter > 0:
        total_training_time = time.time() - start_training_time
        total_time_str = str(datetime.timedelta(seconds=total_training_time))
        logger.info("Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / (max_iter)))
Example #7
def do_train(model,
             data_loader,
             optimizer,
             scheduler,
             checkpointer,
             device,
             checkpoint_period,
             arguments,
             logger,
             tensorboard_writer: TensorboardWriter = None):
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()

    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):

        if any(len(target) < 1 for target in targets):
            logger.error(
                f"Iteration={iteration + 1} || Image Ids used for training {_} || "
                f"targets Length={[len(target) for target in targets]}")
            continue

        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        scheduler.step()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        result, loss_dict = model(images, targets)
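        # this model variant returns both its predictions and a loss dict during training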

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        # Note: If mixed precision is not used, this ends up doing nothing
        # Otherwise apply loss scaling for mixed-precision recipe
        with amp.scale_loss(losses, optimizer) as scaled_losses:
            scaled_losses.backward()
        optimizer.step()

        # write images / ground truth / evaluation metrics to tensorboard
        tensorboard_writer(iteration, losses_reduced, loss_dict_reduced,
                           images, targets)

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)
        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if get_world_size() < 2 or dist.get_rank() == 0:
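            # only the main process logs progress when running multi-GPU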
            if iteration % 20 == 0 or iteration == max_iter:
                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                    ]).format(
                        eta=eta_string,
                        iter=iteration,
                        meters=str(meters),
                        lr=optimizer.param_groups[0]["lr"],
                    ))
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))