コード例 #1
0
ファイル: train.py プロジェクト: zbxzc35/Graph2plan
    def test(engine, batch):
        """Test step: predict boxes/layouts for one batch and record per-sample results.

        Relies on closure variables ``model``, ``args`` and ``output`` (a dict keyed by
        sample name that this function fills in).

        Boxes are converted from center form (x_c, y_c, w, h) to extents
        (x0, y0, x1, y1) before IoU is computed. In the generated layout, pixels where
        ``boundary[i, 0] == 0`` are set to label 13. 13 is also the label passed to
        ``image_acc_ignore`` for the foreground accuracy.

        Returns:
            dict with 'pred' = [boxes_pred, gene_preds or None, boxes_refine or None]
            and 'gt' = [layout, boxes], consumed by the attached metrics.
        """
        outside_label = 13  # label for pixels outside the boundary mask
        model.eval()
        with torch.no_grad():
            boundary, inside_box, objs, attrs, triples, layout, boxes, inside_coords, obj_to_img, triple_to_img, name = batch_cuda(batch)

            # Convert ground-truth boxes to absolute coordinates BEFORE the model call,
            # as the training and validation steps do. When args.gt_box is set, the
            # model then receives boxes_gt in the same frame it was trained with.
            if args.relative:
                boxes = box_rel2abs(boxes, inside_box, obj_to_img)

            boxes_pred, gene_layout, boxes_refine = model(
                objs,
                triples,
                boundary,
                obj_to_img=obj_to_img,
                attributes=attrs,
                boxes_gt=boxes if args.gt_box else None,
                generate=args.gene_layout,
                refine=args.box_refine,
                relative=args.relative,
                inside_box=inside_box if args.relative else None,
            )

            # box: x_c, y_c, w, h -> x0, y0, x1, y1
            boxes_pred = centers_to_extents(boxes_pred.detach())
            if args.box_refine:
                boxes_refine = centers_to_extents(boxes_refine.detach())
            boxes = centers_to_extents(boxes)

            # layout: B*C*H*W -> B*H*W (per-pixel argmax over channel dim 1)
            gene_preds = None
            if args.gene_layout:
                gene_layout = gene_layout * boundary[:, :1]
                gene_preds = torch.argmax(gene_layout.softmax(1).detach(), dim=1)
                # Mark everything outside the boundary with the dedicated label.
                for i in range(len(layout)):
                    gene_preds[i][boundary[i, 0] == 0] = outside_label

            # metrics
            box_ious = iou(boxes_pred, boxes)
            box_refine_ious = iou(boxes_refine, boxes) if args.box_refine else None

            gene_acc_all = None
            gene_acc_fg = None
            if args.gene_layout:
                gene_acc_all = image_acc(gene_preds, layout)
                gene_acc_fg = image_acc_ignore(gene_preds, layout, outside_label)

            # save per-sample output (objects selected by their image index)
            for i in range(len(layout)):
                in_img = obj_to_img == i

                box_refine = None
                box_refine_iou = None
                if args.box_refine:
                    box_refine = boxes_refine[in_img].cpu().numpy()
                    box_refine_iou = box_refine_ious[in_img].view(-1).cpu().numpy()

                gene_pred = None
                if args.gene_layout:
                    gene_pred = gene_preds[i].cpu().numpy().astype('uint8')

                output[name[i]] = {
                    'obj': objs[in_img].cpu().numpy(),
                    'box_gt': boxes[in_img].cpu().numpy(),

                    'box_pred': boxes_pred[in_img].cpu().numpy(),
                    'box_iou': box_ious[in_img].view(-1).cpu().numpy(),

                    'box_refine': box_refine,
                    'box_refine_iou': box_refine_iou,

                    'gene_pred': gene_pred,
                    'gene_acc_all': gene_acc_all[i].item() if args.gene_layout else None,
                    'gene_acc_fg': gene_acc_fg[i].item() if args.gene_layout else None,
                }
            return {
                'pred': [
                    boxes_pred,    # 0
                    gene_preds,    # 1 (None when layouts are not generated)
                    boxes_refine if args.box_refine else None,  # 2
                ],
                'gt': [layout, boxes],
            }
コード例 #2
0
ファイル: test_metrics.py プロジェクト: kqf/ds-bowl-2018
def test_calculates_metrics(true_labels, logits):
    """IoU between the integer-cast ``true_labels`` fixture and ``logits`` is exactly 0."""
    labels_int = true_labels.astype(int)
    score = iou(labels_int, logits)
    assert score == 0.0
コード例 #3
0
ファイル: train.py プロジェクト: zbxzc35/Graph2plan
def main(args):
    """Train, validate and test a DeepLayout model, writing all artifacts to disk.

    Steps:
      1. Create ``../experiment/<date>/DeepLayout_<time>/`` with ``checkpoints/``,
         ``logs/`` (log.txt, tensorboard events, copies of this script and ./model)
         and ``output/``.
      2. Build loaders, model, optimizer, scheduler and losses via the project
         helpers; optionally load ``args.pretrain`` weights.
      3. Unless ``args.skip_train``, run the ignite trainer for ``args.epoch``
         epochs, validating and stepping the LR scheduler after each epoch.
      4. Run a test pass and save its metrics to ``output_<time>_metrics.json``
         and per-sample predictions to ``output_<time>.pkl``. The test pass uses
         the validation loader after training, or the test loader when training
         was skipped.

    Args:
        args: parsed command-line namespace (the ``args.*`` fields read below).
    """
    # Label written into generated layouts outside the boundary; also the label
    # passed to image_acc_ignore for foreground accuracy.
    outside_label = 13

    args.epoch = args.epoch if not args.debug else 6
    print("Create dir...")
    start_date = (datetime.datetime.now().strftime('%Y-%m-%d')
                  + ("" if not args.debug else "_debug")
                  + ("" if not args.test else "_test"))
    if not os.path.exists('../experiment'):
        os.mkdir('../experiment')
    experiment_dir = path.Path(f'../experiment/{start_date}')
    experiment_dir.mkdir(exist_ok=True)
    # Append the optional suffix to the timestamp. Without the parentheses the
    # conditional expression covered the whole concatenation, so a suffix
    # replaced the timestamp and reruns with the same suffix collided
    # (copytree below refuses to overwrite an existing directory).
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    start_time = timestamp + ('' if args.suffix is None else args.suffix)
    file_dir = path.Path(f'{experiment_dir}/DeepLayout_{start_time}')
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    # Snapshot the training script and model package next to the logs.
    shutil.copy(__file__, log_dir/'train.py')
    shutil.copytree('./model', log_dir/'model')
    output_dir = file_dir.joinpath('output/')
    output_dir.mkdir(exist_ok=True)

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(str(log_dir)+'/log.txt')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    # After a training run, log a test-only re-run command template
    # (the --pretrain value is left empty; presumably filled in by hand).
    if args.skip_train:
        logger.info(f'python {args.argv}')
    else:
        logger.info(f'python {args.argv} --skip_train 1 --pretrain ')
    logger.info(args)
    logger.info('---------------------------------------------------TRANING---------------------------------------------------')
    logger.info(f'Use seed: {args.seed}')
    # check_manual_seed(args)

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu if args.multi_gpu is None else args.multi_gpu

    print("Create dataloader...")
    train_loader, valid_loader, test_loader = get_data_loaders(args)
    print("Create model...")
    model = get_model(args)
    print("Gene:", model.refinement_net is not None and args.gene_layout)
    print("Refine:", args.box_refine)
    print("Cat feat:", args.roi_cat_feature)
    print("GT BOX:", args.gt_box)
    print("Iniside Loss:", args.inside)
    print("Coverage Loss:", args.coverage)
    print("Mutex Loss:", args.mutex)
    print("Render Loss:", args.render)
    logger.info(argparse.Namespace(
        embedding_dim=args.embedding_dim,
        image_size=args.image_size,
        input_dim=args.input_dim,
        attribute_dim=args.pos_dim + args.area_dim,
        refinement_dims=args.refinement_dims if args.gene_layout else None,
        box_refine_arch=args.box_refine_arch if args.box_refine else None,
        roi_cat_feature=args.roi_cat_feature))
    logger.info(str(model))
    optimizer = get_optimizer(model, args)
    scheduler = get_scheduler(optimizer, args)
    loss = get_losses(args)

    if args.pretrain is not None:
        model.load_state_dict(torch.load(args.pretrain))

    print("Cuda...")
    model.cuda()

    def update(engine, batch):
        """Training step: forward, weighted losses, backward, optimizer step.

        Curriculum by trainer epoch:
          - epoch 1: box regression only;
          - epoch 2+: layout generation and auxiliary box losses;
          - epoch 3+: box refinement.
        Newly enabled gene_ce / box_ref_mse terms ramp through weights 0.1, 0.5, 1.0.

        Returns:
            dict of per-loss floats plus 'total_loss'.
        """
        model.train()
        optimizer.zero_grad()

        boundary, inside_box, objs, attrs, triples, layout, boxes, inside_coords, obj_to_img, triple_to_img, name = batch_cuda(batch)

        if args.relative:
            boxes = box_rel2abs(boxes, inside_box, obj_to_img)

        epoch = engine.state.epoch
        boxes_pred, gene_layout, boxes_refine = model(
            objs,
            triples,
            boundary,
            obj_to_img=obj_to_img,
            attributes=attrs,
            boxes_gt=boxes if args.gt_box else None,
            generate=args.gene_layout and epoch > 1,
            refine=args.box_refine and epoch > 2,
            relative=args.relative,
            inside_box=inside_box if args.relative else None,
        )

        total_loss = 0
        loss_items = {}
        step_weight = [0.1, 0.5, 1.0]
        # Refined boxes only exist (and are only penalised) from epoch 3 on.
        use_refine_loss = args.box_refine and args.loss_refine and epoch > 2
        for loss_name in loss:
            loss_fn = loss[loss_name]
            l = None
            if loss_name == 'box_mse':
                l = loss_fn(boxes_pred, boxes)
            else:
                if epoch > 1:
                    if loss_name == 'gene_ce':
                        l = step_weight[epoch - 2 if epoch <= 3 else -1] * loss_fn(gene_layout, layout)
                    elif loss_name == 'mutex':
                        l = 0.1 * loss_fn(boxes_pred, obj_to_img, objs)
                        if use_refine_loss:
                            l += loss_fn(boxes_refine, obj_to_img, objs)
                    elif loss_name == 'inside':
                        l = 0.1 * loss_fn(boxes_pred, inside_box, obj_to_img)
                        if use_refine_loss:
                            l += loss_fn(boxes_refine, inside_box, obj_to_img)
                    elif loss_name == 'coverage':
                        l = 0.1 * loss_fn(boxes_pred, inside_coords, obj_to_img)
                        if use_refine_loss:
                            l += loss_fn(boxes_refine, inside_coords, obj_to_img)
                    elif loss_name == 'render':
                        l = loss_fn(boxes_pred, boxes)
                        if use_refine_loss:
                            l += loss_fn(boxes_refine, boxes)

                if epoch > 2 and loss_name == 'box_ref_mse':
                    l = step_weight[epoch - 3 if epoch <= 4 else -1] * loss_fn(boxes_refine, boxes)

            if l is not None:
                total_loss += l
                loss_items[loss_name] = l.item()
        loss_items['total_loss'] = total_loss.item()

        total_loss.backward()
        optimizer.step()
        return loss_items

    def inference(engine, batch):
        """Validation step: losses plus predictions/targets for the metrics.

        Loss terms are gated on the *trainer's* epoch. valid_evaluator.run() is
        called without max_epochs (ignite default 1), so the evaluator's own
        state.epoch is 1 on every pass. Gating on it meant validation only ever
        reported 'box_mse'.

        Returns:
            dict with 'loss' (per-loss floats and 'total_loss'), 'pred'
            ([boxes_pred, gene_layout or None, boxes_refine or None]; boxes as
            x0, y0, x1, y1) and 'gt' ([layout, boxes]).
        """
        model.eval()
        epoch = trainer.state.epoch
        with torch.no_grad():
            boundary, inside_box, objs, attrs, triples, layout, boxes, inside_coords, obj_to_img, triple_to_img, name = batch_cuda(batch)

            if args.relative:
                boxes = box_rel2abs(boxes, inside_box, obj_to_img)

            boxes_pred, gene_layout, boxes_refine = model(
                objs,
                triples,
                boundary,
                obj_to_img=obj_to_img,
                attributes=attrs,
                boxes_gt=boxes if args.gt_box else None,
                generate=args.gene_layout,
                refine=args.box_refine,
                relative=args.relative,
                inside_box=inside_box if args.relative else None,
            )

            # NOTE(review): refinement terms are weighted 0.1 here but 1.0 in update(),
            # and gene_ce / box_ref_mse get no ramp-up weight; confirm this is intended.
            use_refine_loss = args.box_refine and args.loss_refine
            total_loss = 0
            loss_items = {}
            for loss_name in loss:
                loss_fn = loss[loss_name]
                l = None
                if loss_name == 'box_mse':
                    l = loss_fn(boxes_pred, boxes)
                if epoch > 1:
                    if loss_name == 'gene_ce':
                        l = loss_fn(gene_layout, layout)
                    elif loss_name == 'mutex':
                        l = 0.1 * loss_fn(boxes_pred, obj_to_img, objs)
                        if use_refine_loss:
                            l += 0.1 * loss_fn(boxes_refine, obj_to_img, objs)
                    elif loss_name == 'inside':
                        l = 0.1 * loss_fn(boxes_pred, inside_box, obj_to_img)
                        if use_refine_loss:
                            l += 0.1 * loss_fn(boxes_refine, inside_box, obj_to_img)
                    elif loss_name == 'coverage':
                        l = 0.1 * loss_fn(boxes_pred, inside_coords, obj_to_img)
                        if use_refine_loss:
                            l += 0.1 * loss_fn(boxes_refine, inside_coords, obj_to_img)
                    elif loss_name == 'render':
                        l = loss_fn(boxes_pred, boxes)
                        if use_refine_loss:
                            l += loss_fn(boxes_refine, boxes)

                if epoch > 2 and loss_name == 'box_ref_mse':
                    l = loss_fn(boxes_refine, boxes)

                if l is not None:
                    total_loss += l
                    loss_items[loss_name] = l.item()
            loss_items['total_loss'] = total_loss.item()

            # box: x_c, y_c, w, h -> x0, y0, x1, y1
            boxes_pred = centers_to_extents(boxes_pred.detach())

            if args.gene_layout:
                gene_layout = gene_layout * boundary[:, :1]

            if args.box_refine:
                boxes_refine = centers_to_extents(boxes_refine.detach())

            boxes = centers_to_extents(boxes)

            return {
                'loss': loss_items,
                'pred': [
                    boxes_pred,
                    gene_layout.detach() if args.gene_layout else None,
                    boxes_refine if args.box_refine else None,
                ],
                'gt': [layout, boxes],
            }

    print("Create trainer...")
    # Prime optimizer/scheduler before training. No gradients exist yet, so for
    # torch.optim built-ins (which skip params whose .grad is None) this
    # optimizer.step() leaves the weights unchanged.
    # NOTE(review): scheduler.step(0) passes 0 as epoch/metric; for a
    # metric-driven scheduler it records 0 as the first observation. Confirm
    # this against get_scheduler's mode.
    optimizer.step()
    scheduler.step(0)
    trainer = Engine(update)
    valid_evaluator = Engine(inference)

    if args.start_epoch is not None:
        @trainer.on(Events.STARTED)
        def set_up_state(engine):
            engine.state.epoch = args.start_epoch

    def total_func(e):
        """Sum of enabled validation metrics, passed to scheduler.step() for non-'step' schedulers."""
        total = e.state.metrics['box_iou']
        if args.gene_layout:
            total = total + e.state.metrics['gene_acc']
        if args.box_refine:
            total = total + e.state.metrics['box_refine_iou']
        return total

    @valid_evaluator.on(Events.COMPLETED)
    def schedual(engine):
        """Advance the LR scheduler once per validation pass."""
        # Deliberately no optimizer.step() here. The last training batch's gradients
        # are still populated (zero_grad only runs at the start of update), so
        # stepping again would apply an extra, unintended parameter update each epoch.
        if args.scheduler == 'step':
            scheduler.step()
        else:
            scheduler.step(total_func(engine))

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        valid_evaluator.run(valid_loader)

    # Metrics
    MetricAverage(output_transform=lambda output: iou(output['pred'][0], output['gt'][1])).attach(valid_evaluator, 'box_iou')
    if args.gene_layout:
        MetricAverage(output_transform=lambda output: image_acc_ignore(output['pred'][1], output['gt'][0], outside_label)).attach(valid_evaluator, 'gene_acc')
    if args.box_refine:
        MetricAverage(output_transform=lambda output: iou(output['pred'][2], output['gt'][1])).attach(valid_evaluator, 'box_refine_iou')

    # TQDM
    ProgressBar(persist=True).attach(trainer, output_transform=lambda o: {'loss': o['total_loss']}, metric_names='all')
    ProgressBar(persist=False).attach(valid_evaluator, output_transform=lambda o: {'loss': o['loss']['total_loss']}, metric_names='all')

    # Tensorboard
    tb_logger = TensorboardLogger(log_dir=log_dir)
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag="train", output_transform=lambda o: o, metric_names='all'),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.ITERATION_STARTED)
    tb_logger.attach(valid_evaluator,
                     log_handler=OutputHandler(tag="valid", output_transform=lambda o: o['loss'], metric_names='all', global_step_transform=global_step_from_engine(trainer)),
                     event_name=Events.EPOCH_COMPLETED)

    # Logging
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_train_results(engine):
        logging.info(f'Train, Epoch{engine.state.epoch}, Loss: {str(engine.state.output)}')

    @valid_evaluator.on(Events.EPOCH_COMPLETED)
    def log_valid_results(engine):
        valid_loss = engine.state.output['loss']
        valid_metrics = engine.state.metrics
        logging.info(f'Valid, Epoch{engine.state.epoch}, Loss: {str(valid_loss)}')
        logging.info(f'Valid, Epoch{engine.state.epoch}, Metrics: {str(valid_metrics)}')

    # Checkpoint
    # NOTE(review): loss_saver scores on state.output, i.e. the LAST validation
    # batch's loss, not an epoch average.
    epoch_saver = ModelCheckpoint(checkpoints_dir, 'epoch', save_interval=args.save_interval, n_saved=args.n_saved, require_empty=False, create_dir=True)
    latest_saver = ModelCheckpoint(checkpoints_dir, 'latest', score_function=lambda e: e.state.epoch, n_saved=1, require_empty=False, create_dir=True)
    loss_saver = ModelCheckpoint(checkpoints_dir, 'loss', score_function=lambda e: -e.state.output['loss']['total_loss'], n_saved=1, require_empty=False, create_dir=True)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, latest_saver, {'model': model, 'opt': optimizer})
    trainer.add_event_handler(Events.EPOCH_COMPLETED, epoch_saver, {'model': model, 'opt': optimizer})
    valid_evaluator.add_event_handler(Events.COMPLETED, loss_saver, {'model': model})

    if not args.skip_train:
        trainer.run(train_loader, max_epochs=args.epoch)
    tb_logger.close()

    # Per-sample test results keyed by sample name; filled in by test() below.
    output = {}

    def test(engine, batch):
        """Test step: predict boxes/layouts for one batch and record per-sample results.

        Boxes are converted from center form (x_c, y_c, w, h) to extents
        (x0, y0, x1, y1) before IoU. In the generated layout, pixels where
        ``boundary[i, 0] == 0`` receive ``outside_label``.

        Returns:
            dict with 'pred' = [boxes_pred, gene_preds or None, boxes_refine or None]
            and 'gt' = [layout, boxes], consumed by the attached metrics.
        """
        model.eval()
        with torch.no_grad():
            boundary, inside_box, objs, attrs, triples, layout, boxes, inside_coords, obj_to_img, triple_to_img, name = batch_cuda(batch)

            # Convert ground-truth boxes to absolute coordinates BEFORE the model call,
            # exactly as update() and inference() do. When args.gt_box is set, the
            # model then receives boxes_gt in the frame it was trained with.
            if args.relative:
                boxes = box_rel2abs(boxes, inside_box, obj_to_img)

            boxes_pred, gene_layout, boxes_refine = model(
                objs,
                triples,
                boundary,
                obj_to_img=obj_to_img,
                attributes=attrs,
                boxes_gt=boxes if args.gt_box else None,
                generate=args.gene_layout,
                refine=args.box_refine,
                relative=args.relative,
                inside_box=inside_box if args.relative else None,
            )

            # box: x_c, y_c, w, h -> x0, y0, x1, y1
            boxes_pred = centers_to_extents(boxes_pred.detach())
            if args.box_refine:
                boxes_refine = centers_to_extents(boxes_refine.detach())
            boxes = centers_to_extents(boxes)

            # layout: B*C*H*W -> B*H*W (per-pixel argmax over channel dim 1)
            gene_preds = None
            if args.gene_layout:
                gene_layout = gene_layout * boundary[:, :1]
                gene_preds = torch.argmax(gene_layout.softmax(1).detach(), dim=1)
                for i in range(len(layout)):
                    gene_preds[i][boundary[i, 0] == 0] = outside_label

            # metrics
            box_ious = iou(boxes_pred, boxes)
            box_refine_ious = iou(boxes_refine, boxes) if args.box_refine else None

            gene_acc_all = None
            gene_acc_fg = None
            if args.gene_layout:
                gene_acc_all = image_acc(gene_preds, layout)
                gene_acc_fg = image_acc_ignore(gene_preds, layout, outside_label)

            # save per-sample output (objects selected by their image index)
            for i in range(len(layout)):
                in_img = obj_to_img == i

                box_refine = None
                box_refine_iou = None
                if args.box_refine:
                    box_refine = boxes_refine[in_img].cpu().numpy()
                    box_refine_iou = box_refine_ious[in_img].view(-1).cpu().numpy()

                gene_pred = None
                if args.gene_layout:
                    gene_pred = gene_preds[i].cpu().numpy().astype('uint8')

                output[name[i]] = {
                    'obj': objs[in_img].cpu().numpy(),
                    'box_gt': boxes[in_img].cpu().numpy(),

                    'box_pred': boxes_pred[in_img].cpu().numpy(),
                    'box_iou': box_ious[in_img].view(-1).cpu().numpy(),

                    'box_refine': box_refine,
                    'box_refine_iou': box_refine_iou,

                    'gene_pred': gene_pred,
                    'gene_acc_all': gene_acc_all[i].item() if args.gene_layout else None,
                    'gene_acc_fg': gene_acc_fg[i].item() if args.gene_layout else None,
                }
            return {
                'pred': [
                    boxes_pred,    # 0
                    gene_preds,    # 1 (None when layouts are not generated)
                    boxes_refine if args.box_refine else None,  # 2
                ],
                'gt': [layout, boxes],
            }

    test_evaluator = Engine(test)

    MetricAverage(output_transform=lambda output: iou(output['pred'][0], output['gt'][1])).attach(test_evaluator, 'box_iou')

    if args.gene_layout:
        MetricAverage(output_transform=lambda output: image_acc_ignore(output['pred'][1], output['gt'][0], outside_label)).attach(test_evaluator, 'gene_acc')
        MetricAverage(output_transform=lambda output: image_acc(output['pred'][1], output['gt'][0])).attach(test_evaluator, 'gene_acc_all')
    if args.box_refine:
        MetricAverage(output_transform=lambda output: iou(output['pred'][2], output['gt'][1])).attach(test_evaluator, 'box_refine_iou')

    ProgressBar(persist=False).attach(test_evaluator)

    @test_evaluator.on(Events.COMPLETED)
    def save_metrics(engine):
        # NOTE: writes str(dict), which is not guaranteed to be strict JSON.
        metrics = engine.state.metrics
        with open(f'{output_dir}/output_{start_time}_metrics.json', 'w') as f:
            f.write(str(metrics))

    if not args.skip_train:
        test_evaluator.run(valid_loader)
    else:
        test_evaluator.run(test_loader)
    with open(f'{output_dir}/output_{start_time}.pkl', 'wb') as f:
        pickle.dump(output, f, pickle.HIGHEST_PROTOCOL)