Example #1
0
def epoch_train(model, dataloader, dataset, criterion, optimizer, scheduler,
                device, data_const):
    """Epoch-based train/validation loop for the model.

    For each epoch in ``args.start_epoch .. args.epoch`` runs a 'train' and
    a 'val' phase, logs losses and sample visualizations to TensorBoard, and
    saves checkpoints under ``args.save_dir/args.exp_ver/epoch_train``.

    Relies on the module-level ``args`` namespace and module-level imports
    (SummaryWriter, io, tqdm, vis_img_vcoco, VcocoDataset, ...).
    ``scheduler`` is accepted but currently unused — its step() call below
    is commented out.  ``data_const`` supplies ``original_image_dir`` used
    when rendering validation images.
    """
    print('epoch training...')

    # set visualization and create folder to save checkpoints
    writer = SummaryWriter(log_dir=args.log_dir + '/' + args.exp_ver + '/' +
                           'epoch_train')
    io.mkdir_if_not_exists(os.path.join(args.save_dir, args.exp_ver,
                                        'epoch_train'),
                           recursive=True)

    for epoch in range(args.start_epoch, args.epoch):
        # each epoch has a training and validation step
        epoch_loss = 0
        for phase in ['train', 'val']:
            start_time = time.time()
            running_loss = 0
            # all_edge = 0
            idx = 0

            # reset the class-level sample counter before each phase
            VcocoDataset.data_sample_count = 0
            for data in tqdm(dataloader[phase]):
                train_data = data
                img_name = train_data['img_name']
                det_boxes = train_data['det_boxes']
                roi_labels = train_data['roi_labels']
                roi_scores = train_data['roi_scores']
                node_num = train_data['node_num']
                edge_labels = train_data['edge_labels']
                edge_num = train_data['edge_num']
                features = train_data['features']
                spatial_feat = train_data['spatial_feat']
                word2vec = train_data['word2vec']
                # move the tensors the model consumes onto the target device
                features, spatial_feat, word2vec, edge_labels = features.to(
                    device), spatial_feat.to(device), word2vec.to(
                        device), edge_labels.to(device)
                # NOTE(review): truncates every phase to 10 batches — looks
                # like a debugging leftover.  epoch_loss below still divides
                # by the full dataset size, so reported losses are not
                # meaningful while this line is active.  Confirm intent.
                if idx == 10: break
                if phase == 'train':
                    model.train()
                    model.zero_grad()
                    outputs = model(node_num, features, spatial_feat, word2vec,
                                    roi_labels)
                    # import ipdb; ipdb.set_trace()
                    # criterion receives raw logits (BCEWithLogitsLoss style)
                    loss = criterion(outputs, edge_labels.float())
                    loss.backward()
                    optimizer.step()

                else:
                    model.eval()
                    # turn off the gradients for validation, save memory and computations
                    with torch.no_grad():
                        outputs = model(node_num,
                                        features,
                                        spatial_feat,
                                        word2vec,
                                        roi_labels,
                                        validation=True)
                        loss = criterion(outputs, edge_labels.float())
                    # print result every 1000 iteration during validation
                    if idx == 0 or idx % round(
                            1000 / args.batch_size) == round(
                                1000 / args.batch_size) - 1:
                        # ipdb.set_trace()
                        # img_name is stored as a uint8 array; decode it back
                        # to an ascii path.  NOTE(review): ndarray.tostring()
                        # is deprecated in NumPy — tobytes() is the modern
                        # equivalent.
                        image = Image.open(
                            os.path.join(
                                data_const.original_image_dir,
                                img_name[0][:].astype(
                                    np.uint8).tostring().decode(
                                        'ascii'))).convert('RGB')
                        image_temp = image.copy()
                        # sigmoid over the first sample's edge logits only
                        raw_outputs = nn.Sigmoid()(outputs[0:int(edge_num[0])])
                        raw_outputs = raw_outputs.cpu().detach().numpy()
                        # class_img = vis_img(image, det_boxes, roi_labels, roi_scores)
                        # ground-truth visualization for the first sample
                        class_img = vis_img_vcoco(
                            image,
                            det_boxes[0],
                            roi_labels[0],
                            roi_scores[0],
                            edge_labels[0:int(edge_num[0])].cpu().numpy(),
                            score_thresh=0.7)
                        # predicted-action visualization for the same sample
                        action_img = vis_img_vcoco(image_temp,
                                                   det_boxes[0],
                                                   roi_labels[0],
                                                   roi_scores[0],
                                                   raw_outputs,
                                                   score_thresh=0.5)
                        # TensorBoard expects CHW; PIL arrays are HWC
                        writer.add_image(
                            'gt_detection',
                            np.array(class_img).transpose(2, 0, 1))
                        writer.add_image(
                            'action_detection',
                            np.array(action_img).transpose(2, 0, 1))
                        writer.add_text(
                            'img_name', img_name[0][:].astype(
                                np.uint8).tostring().decode('ascii'), epoch)

                idx += 1
                # accumulate loss of each batch
                running_loss += loss.item() * edge_labels.shape[0]
                # all_edge += edge_labels.shape[0]
            # calculate the loss and accuracy of each epoch
            epoch_loss = running_loss / len(dataset[phase])
            # epoch_loss = running_loss / all_edge
            # import ipdb; ipdb.set_trace()
            # log trainval datas, and visualize them in the same graph
            if phase == 'train':
                train_loss = epoch_loss
                VcocoDataset.displaycount()
            else:
                # plot train and val loss in the same TensorBoard chart
                writer.add_scalars('trainval_loss_epoch', {
                    'train': train_loss,
                    'val': epoch_loss
                }, epoch)
            # print data
            if (epoch % args.print_every) == 0:
                end_time = time.time()
                print("[{}] Epoch: {}/{} Loss: {} Execution time: {}".format(\
                        phase, epoch+1, args.epoch, epoch_loss, (end_time-start_time)))

        # scheduler.step()
        # save model
        # NOTE(review): 'and' binds tighter than 'or', so this saves whenever
        # epoch_loss < 0.0405, OR on every args.save_every-th epoch once
        # epoch >= 499.  Confirm that grouping is intended.
        if epoch_loss < 0.0405 or epoch % args.save_every == (
                args.save_every - 1) and epoch >= (500 - 1):
            checkpoint = {
                'lr': args.lr,
                'b_s': args.batch_size,
                'bias': args.bias,
                'bn': args.bn,
                'dropout': args.drop_prob,
                'layers': args.layers,
                'feat_type': args.feat_type,
                'multi_head': args.multi_attn,
                'diff_edge': args.diff_edge,
                'state_dict': model.state_dict()
            }
            save_name = "checkpoint_" + str(epoch + 1) + '_epoch.pth'
            torch.save(
                checkpoint,
                os.path.join(args.save_dir, args.exp_ver, 'epoch_train',
                             save_name))

    writer.close()
    print('Finishing training!')
Example #2
0
def run_model(args, data_const):
    """Set up V-COCO training and dispatch to the chosen training loop.

    Builds the train/val datasets and dataloaders, constructs the AGRNN
    model (optionally warm-started from a HICO-DET checkpoint with the
    readout head swapped, and/or from a pretrained checkpoint), sets up
    optimizer/criterion/scheduler, dumps per-layer configuration JSON
    files, then calls ``epoch_train`` or ``iteration_train``.

    Args:
        args: parsed command-line namespace (batch size, lr, paths, model
            hyper-parameters, ...).
        data_const: dataset-constants object forwarded to VcocoDataset and
            the training loop.
    """
    # set up dataset variable
    train_dataset = VcocoDataset(data_const=data_const,
                                 subset='vcoco_train',
                                 data_aug=args.data_aug,
                                 sampler=args.sampler)
    val_dataset = VcocoDataset(data_const=data_const,
                               subset='vcoco_val',
                               data_aug=False,
                               sampler=args.sampler)
    dataset = {'train': train_dataset, 'val': val_dataset}
    print('set up dataset variable successfully')
    # use default DataLoader() to load the data.
    train_dataloader = DataLoader(dataset=dataset['train'],
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn)
    val_dataloader = DataLoader(dataset=dataset['val'],
                                batch_size=args.batch_size,
                                shuffle=True,
                                collate_fn=collate_fn)
    dataloader = {'train': train_dataloader, 'val': val_dataloader}
    print('set up dataloader successfully')

    device = torch.device(
        'cuda' if torch.cuda.is_available() and args.gpu else 'cpu')
    print('training on {}...'.format(device))

    model = AGRNN(feat_type=args.feat_type,
                  bias=args.bias,
                  bn=args.bn,
                  dropout=args.drop_prob,
                  multi_attn=args.multi_attn,
                  layer=args.layers,
                  diff_edge=args.diff_edge,
                  HICO=args.hico)

    # load pretrained model of HICO_DET dataset
    if args.hico:
        print(f"loading pretrained model of HICO_DET dataset {args.hico}")
        checkpoints = torch.load(args.hico, map_location=device)
        model.load_state_dict(checkpoints['state_dict'])
        # swap the readout head: HICO-DET predicts 117 action classes,
        # V-COCO needs 24, so replace the last classifier layer
        model.edge_readout.classifier.layers[1] = Predictor(
            model.CONFIG1).classifier.layers[1]

    # total number of learnable parameters (sum over a generator instead of
    # a manual accumulation loop)
    parameter_num = sum(param.numel() for param in model.parameters())
    print(
        f'The parameters number of the model is {parameter_num / 1e6} million')

    # load pretrained model
    if args.pretrained:
        print(f"loading pretrained model {args.pretrained}")
        checkpoints = torch.load(args.pretrained, map_location=device)
        model.load_state_dict(checkpoints['state_dict'])
    model.to(device)

    # build optimizer && criterion
    if args.optim == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=0)
    else:
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0)
    # criterion = nn.MultiLabelSoftMarginLoss()
    criterion = nn.BCEWithLogitsLoss()
    # multiply the lr by gamma=0.3 every step_size=500 epochs
    # (previous comment claimed "divide by 10 every 150 epochs", which
    # contradicted the actual arguments)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=500,
                                          gamma=0.3)

    # get the configuration of the model and save some key configurations
    io.mkdir_if_not_exists(os.path.join(args.save_dir, args.exp_ver),
                           recursive=True)
    for i in range(args.layers):
        if i == 0:
            # layer-1 config additionally records the run hyper-parameters
            model_config = model.CONFIG1.save_config()
            model_config['lr'] = args.lr
            model_config['bs'] = args.batch_size
            model_config['layers'] = args.layers
            model_config['multi_attn'] = args.multi_attn
            model_config['data_aug'] = args.data_aug
            model_config['drop_out'] = args.drop_prob
            model_config['optimizer'] = args.optim
            model_config['diff_edge'] = args.diff_edge
            model_config['model_parameters'] = parameter_num
            io.dump_json_object(
                model_config,
                os.path.join(args.save_dir, args.exp_ver, 'l1_config.json'))
        elif i == 1:
            model_config = model.CONFIG2.save_config()
            io.dump_json_object(
                model_config,
                os.path.join(args.save_dir, args.exp_ver, 'l2_config.json'))
        else:
            model_config = model.CONFIG3.save_config()
            io.dump_json_object(
                model_config,
                os.path.join(args.save_dir, args.exp_ver, 'l3_config.json'))
    print('save key configurations successfully...')

    if args.train_model == 'epoch':
        epoch_train(model, dataloader, dataset, criterion, optimizer,
                    scheduler, device, data_const)
    else:
        iteration_train(model, dataloader, dataset, criterion, optimizer,
                        scheduler, device, data_const)
Example #3
0
def main(args):
    """Qualitative inference/visualization entry point.

    Loads pretrained VS-GATs and PGception checkpoints for the chosen
    dataset ('hico' or 'vcoco'), runs both models on a list of image
    filenames, and writes side-by-side ground-truth / detection
    visualizations under ``data_const.infer_dir``.
    """
    # Load checkpoint and set up model
    try:
        # use GPU if available else revert to CPU
        device = torch.device(
            'cuda:0' if torch.cuda.is_available() and args.gpu else 'cpu')
        print("Testing on", device)

        # set up model and initialize it with uploaded checkpoint
        # NOTE(review): if args.dataset is neither 'hico' nor 'vcoco',
        # `checkpoint`/`data_const` are never bound; the resulting NameError
        # is swallowed by the broad except below and the process exits with
        # a generic message.  Consider validating args.dataset up front.
        if args.dataset == 'hico':
            # load checkpoint
            checkpoint = torch.load(args.main_pretrained_hico,
                                    map_location=device)
            print('vsgats Checkpoint loaded!')
            pg_checkpoint = torch.load(args.pretrained_hico,
                                       map_location=device)
            data_const = HicoConstants(feat_type=checkpoint['feat_type'])
            vs_gats = vsgat_hico(feat_type=checkpoint['feat_type'],
                                 bias=checkpoint['bias'],
                                 bn=checkpoint['bn'],
                                 dropout=checkpoint['dropout'],
                                 multi_attn=checkpoint['multi_head'],
                                 layer=checkpoint['layers'],
                                 diff_edge=checkpoint['diff_edge'])  #2 )
        if args.dataset == 'vcoco':
            # load checkpoint
            checkpoint = torch.load(args.main_pretrained_vcoco,
                                    map_location=device)
            print('vsgats Checkpoint loaded!')
            pg_checkpoint = torch.load(args.pretrained_vcoco,
                                       map_location=device)
            data_const = VcocoConstants()
            vs_gats = vsgat_vcoco(feat_type=checkpoint['feat_type'],
                                  bias=checkpoint['bias'],
                                  bn=checkpoint['bn'],
                                  dropout=checkpoint['dropout'],
                                  multi_attn=checkpoint['multi_head'],
                                  layer=checkpoint['layers'],
                                  diff_edge=checkpoint['diff_edge'])  #2 )
        vs_gats.load_state_dict(checkpoint['state_dict'])
        vs_gats.to(device)
        vs_gats.eval()

        # echo key PGception hyper-parameters stored in the checkpoint
        print(pg_checkpoint['o_c_l'], pg_checkpoint['b_l'],
              pg_checkpoint['attn'], pg_checkpoint['lr'],
              pg_checkpoint['dropout'])
        # pgception = PGception(action_num=24, classifier_mod='cat', o_c_l=[64,64,128,128], last_h_c=256, bias=pg_checkpoint['bias'], drop=pg_checkpoint['dropout'], bn=pg_checkpoint['bn'])
        pgception = PGception(action_num=pg_checkpoint['a_n'],
                              layers=1,
                              classifier_mod=pg_checkpoint['classifier_mod'],
                              o_c_l=pg_checkpoint['o_c_l'],
                              last_h_c=pg_checkpoint['last_h_c'],
                              bias=pg_checkpoint['bias'],
                              drop=pg_checkpoint['dropout'],
                              bn=pg_checkpoint['bn'],
                              agg_first=pg_checkpoint['agg_first'],
                              attn=pg_checkpoint['attn'],
                              b_l=pg_checkpoint['b_l'])
        # pgception = PGception(action_num=pg_checkpoint['a_n'], drop=pg_checkpoint['dropout'])
        pgception.load_state_dict(pg_checkpoint['state_dict'])
        pgception.to(device)
        pgception.eval()
        print('Constructed model successfully!')
    except Exception as e:
        print('Failed to load checkpoint or construct model!', e)
        sys.exit(1)

    # prepare for data: the "dataloader" here is just a sorted list of image
    # filenames; features come from the project dataset objects below
    if args.dataset == 'hico':
        original_imgs_dir = os.path.join(data_const.infer_dir,
                                         'original_imgs/hico')
        # original_imgs_dir = './datasets/hico/images/test2015'
        save_path = os.path.join(data_const.infer_dir, 'processed_imgs/hico')
        test_dataset = HicoDataset(data_const=data_const, subset='test')
        dataloader = sorted(os.listdir(original_imgs_dir))
        # dataloader = ['HICO_test2015_00000128.jpg']
    else:
        original_imgs_dir = os.path.join(data_const.infer_dir,
                                         'original_imgs/vcoco')
        # original_imgs_dir = './datasets/vcoco/coco/images/val2014'
        save_path = os.path.join(data_const.infer_dir, 'processed_imgs/vcoco')
        test_dataset = VcocoDataset(data_const=data_const,
                                    subset='vcoco_test',
                                    pg_only=False)
        # dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, collate_fn=vcoco_collate_fn)
        dataloader = sorted(os.listdir(original_imgs_dir))
        # NOTE(review): the line below overrides the listing with a single
        # hard-coded image — looks like a debugging leftover; confirm.
        dataloader = ['COCO_val2014_000000150361.jpg']

    if not os.path.exists(original_imgs_dir):
        os.makedirs(original_imgs_dir)
    if not os.path.exists(save_path):
        # NOTE(review): os.mkdir fails if parent dirs are missing, and the
        # message below lacks a space before the path.
        os.mkdir(save_path)
        print('result images will be kept here{}'.format(save_path))

    # ipdb.set_trace()
    for data in tqdm(dataloader):
        # load corresponding data
        # print("Testing on image named {}".format(img))
        if args.dataset == 'hico':
            img = data
            global_id = data.split('.')[0]
            # NOTE(review): 'sample_date' is presumably the project's
            # (misspelled) name for sampling one item by id — verify against
            # the dataset class.
            test_data = test_dataset.sample_date(global_id)
            test_data = collate_fn([test_data])
            det_boxes = test_data['det_boxes'][0]
            roi_scores = test_data['roi_scores'][0]
            roi_labels = test_data['roi_labels'][0]
            keypoints = test_data['keypoints'][0]
            edge_labels = test_data['edge_labels']
            node_num = test_data['node_num']
            features = test_data['features']
            spatial_feat = test_data['spatial_feat']
            word2vec = test_data['word2vec']
            pose_normalized = test_data["pose_to_human"]
            pose_to_obj_offset = test_data["pose_to_obj_offset"]
        else:
            # global_id = data['global_id'][0]
            img = data
            # COCO ids are the numeric suffix of the filename, unpadded
            global_id = str(int((data.split('.')[0].split('_')[-1])))
            test_data = test_dataset.sample_date(global_id)
            test_data = vcoco_collate_fn([test_data])
            # img = data['img_name'][0][:].astype(np.uint8).tostring().decode('ascii').split("/")[-1]
            # test_data = data
            det_boxes = test_data['det_boxes'][0]
            roi_scores = test_data['roi_scores'][0]
            roi_labels = test_data['roi_labels'][0]
            edge_labels = test_data['edge_labels']
            node_num = test_data['node_num']
            features = test_data['features']
            spatial_feat = test_data['spatial_feat']
            word2vec = test_data['word2vec']
            pose_normalized = test_data["pose_to_human"]
            pose_to_obj_offset = test_data["pose_to_obj_offset"]

        # inference: move all model inputs onto the target device
        pose_to_obj_offset, pose_normalized, features, spatial_feat, word2vec = pose_to_obj_offset.to(
            device), pose_normalized.to(device), features.to(
                device), spatial_feat.to(device), word2vec.to(device)
        outputs, attn, attn_lang = vs_gats(
            node_num, features, spatial_feat, word2vec,
            [roi_labels])  # !NOTE: it is important to set [roi_labels]
        pg_outputs = pgception(pose_normalized, pose_to_obj_offset)
        # action_score = nn.Sigmoid()(outputs+pg_outputs)
        # action_score = action_score.cpu().detach().numpy()
        # fuse the two branches' logits, then squash to per-action scores
        det_outputs = nn.Sigmoid()(outputs + pg_outputs)
        det_outputs = det_outputs.cpu().detach().numpy()

        # show result
        # import ipdb; ipdb.set_trace()
        if args.dataset == 'hico':
            image = Image.open(
                os.path.join('datasets/hico/images/test2015',
                             img)).convert('RGB')
            image_temp = image.copy()
            gt_img = vis_img(image,
                             det_boxes,
                             roi_labels,
                             roi_scores,
                             edge_labels.cpu().numpy(),
                             score_thresh=0.5)
            det_img = vis_img(image_temp,
                              det_boxes,
                              roi_labels,
                              roi_scores,
                              det_outputs,
                              score_thresh=0.5)
        if args.dataset == 'vcoco':
            image = Image.open(
                os.path.join(data_const.original_image_dir, 'val2014',
                             img)).convert('RGB')
            image_temp = image.copy()
            gt_img = vis_img_vcoco(image,
                                   det_boxes,
                                   roi_labels,
                                   roi_scores,
                                   edge_labels.cpu().numpy(),
                                   score_thresh=0.1)
            det_img = vis_img_vcoco(image_temp,
                                    det_boxes,
                                    roi_labels,
                                    roi_scores,
                                    det_outputs,
                                    score_thresh=0.5)

        # det_img.save('/home/birl/ml_dl_projects/bigjun/hoi/VS_GATs/inference_imgs/original_imgs'+'/'+img)
        det_img.save(save_path + '/' + img.split("/")[-1])
Example #4
0
def main(args):
    """Run V-COCO test-set inference with VS-GATs + PGception and evaluate.

    Loads both pretrained networks, runs inference over the test split to
    build a detection-result list (skipped when a cached pickle exists and
    ``args.rewrite`` is False), saves it, and runs the official V-COCO
    evaluation.  Exits with status 1 if model construction fails.
    """
    # use GPU if available else revert to CPU
    device = torch.device(
        'cuda' if torch.cuda.is_available() and args.gpu else 'cpu')
    print("Testing on", device)

    # Load checkpoint and set up model
    try:
        # load checkpoint
        checkpoint = torch.load(args.main_pretrained, map_location=device)
        print('vsgats Checkpoint loaded!')
        pg_checkpoint = torch.load(args.pretrained, map_location=device)

        # derive an experiment-version tag from the checkpoint path if the
        # caller did not supply one
        if not args.exp_ver:
            args.exp_ver = args.pretrained.split(
                "/")[-2] + "_" + args.pretrained.split("/")[-1].split("_")[-2]
        data_const = VcocoConstants(feat_type=checkpoint['feat_type'],
                                    exp_ver=args.exp_ver)
        vs_gats = AGRNN(feat_type=checkpoint['feat_type'],
                        bias=checkpoint['bias'],
                        bn=checkpoint['bn'],
                        dropout=checkpoint['dropout'],
                        multi_attn=checkpoint['multi_head'],
                        layer=checkpoint['layers'],
                        diff_edge=checkpoint['diff_edge'])  #2 )
        vs_gats.load_state_dict(checkpoint['state_dict'])
        vs_gats.to(device)
        vs_gats.eval()

        print(pg_checkpoint['o_c_l'], pg_checkpoint['lr'],
              pg_checkpoint['dropout'])
        # Build the PGception kwargs once; 'b_l' only exists in newer
        # checkpoints, so it is added conditionally instead of duplicating
        # the whole constructor call.
        pg_kwargs = dict(action_num=pg_checkpoint['a_n'],
                         layers=1,
                         classifier_mod=pg_checkpoint['classifier_mod'],
                         o_c_l=pg_checkpoint['o_c_l'],
                         last_h_c=pg_checkpoint['last_h_c'],
                         bias=pg_checkpoint['bias'],
                         drop=pg_checkpoint['dropout'],
                         bn=pg_checkpoint['bn'],
                         agg_first=pg_checkpoint['agg_first'],
                         attn=pg_checkpoint['attn'])
        if 'b_l' in pg_checkpoint:
            print(pg_checkpoint['b_l'])
            pg_kwargs['b_l'] = pg_checkpoint['b_l']
        pgception = PGception(**pg_kwargs)
        pgception.load_state_dict(pg_checkpoint['state_dict'])
        pgception.to(device)
        pgception.eval()
        print('Constructed model successfully!')
    except Exception as e:
        print('Failed to load checkpoint or construct model!', e)
        sys.exit(1)

    io.mkdir_if_not_exists(data_const.result_dir, recursive=True)
    det_save_file = os.path.join(data_const.result_dir,
                                 'detection_results.pkl')
    if not os.path.isfile(det_save_file) or args.rewrite:
        test_dataset = VcocoDataset(data_const=data_const,
                                    subset='vcoco_test',
                                    pg_only=False)
        test_dataloader = DataLoader(dataset=test_dataset,
                                     batch_size=1,
                                     shuffle=False,
                                     collate_fn=collate_fn)
        # save detection result
        det_data_list = []
        for data in tqdm(test_dataloader):
            global_id = data['global_id'][0]
            det_boxes = data['det_boxes'][0]
            roi_scores = data['roi_scores'][0]
            roi_labels = data['roi_labels'][0]
            node_num = data['node_num']
            features = data['features']
            spatial_feat = data['spatial_feat']
            word2vec = data['word2vec']
            pose_normalized = data["pose_to_human"]
            pose_to_obj_offset = data["pose_to_obj_offset"]

            # move all model inputs onto the target device
            features, spatial_feat, word2vec = features.to(
                device), spatial_feat.to(device), word2vec.to(device)
            pose_to_obj_offset, pose_normalized = pose_to_obj_offset.to(
                device), pose_normalized.to(device)

            outputs, attn, attn_lang = vs_gats(
                node_num, features, spatial_feat, word2vec,
                [roi_labels])  # !NOTE: it is important to set [roi_labels]

            # FIX: this condition previously tested 'b_l' on the vsgats
            # checkpoint ('checkpoint'), which never stores it (it is read
            # from pg_checkpoint above), so the two-output branch was dead
            # code.  The flag lives in the PGception checkpoint.
            if 'b_l' in pg_checkpoint and 4 in pg_checkpoint['b_l']:
                pg_outputs1, pg_outputs2 = pgception(pose_normalized,
                                                     pose_to_obj_offset)
                action_scores = nn.Sigmoid()(outputs + pg_outputs1 +
                                             pg_outputs2)
            else:
                pg_outputs = pgception(pose_normalized, pose_to_obj_offset)
                action_scores = nn.Sigmoid()(outputs + pg_outputs)

            action_scores = action_scores.cpu().detach().numpy()

            # build one detection record per (human, object) pair
            h_idxs = np.where(roi_labels == 1)[0]
            for h_idx in h_idxs:
                for i_idx in range(node_num[0]):
                    if i_idx == h_idx:
                        continue
                    # save hoi results in single image
                    single_result = {}
                    single_result['image_id'] = global_id
                    single_result['person_box'] = det_boxes[h_idx, :]
                    # flattened pair index: each node has (node_num-1) edges
                    # and object indices after the human slot shift down by 1
                    if h_idx > i_idx:
                        edge_idx = h_idx * (node_num[0] - 1) + i_idx
                    else:
                        edge_idx = h_idx * (node_num[0] - 1) + i_idx - 1
                    # combined detection confidences x per-action scores
                    # (a try/except wrapping this line with an ipdb
                    # breakpoint was a debug leftover and has been removed;
                    # indexing errors now propagate with a full traceback)
                    score = roi_scores[h_idx] * roi_scores[
                        i_idx] * action_scores[edge_idx]
                    for action in vcoco_metadata.action_class_with_object:
                        if action == 'none':
                            continue
                        action_idx = vcoco_metadata.action_with_obj_index[
                            action]
                        single_action_score = score[action_idx]
                        if action == 'cut_with' or action == 'eat_with' or action == 'hit_with':
                            # the *_with actions map to the base verb with
                            # the 'instr' role
                            action = action.split('_')[0]
                            role_name = 'instr'
                        else:
                            role_name = vcoco_metadata.action_roles[action][1]
                        action_role_key = '{}_{}'.format(action, role_name)
                        single_result[action_role_key] = np.append(
                            det_boxes[i_idx, :], single_action_score)

                    det_data_list.append(single_result)
        # save all detected results (context manager closes the file handle,
        # which the previous pickle.dump(..., open(...)) left dangling)
        with open(det_save_file, 'wb') as f:
            pickle.dump(det_data_list, f)
    # evaluate
    vcocoeval = VCOCOeval(
        os.path.join(data_const.original_data_dir,
                     'data/vcoco/vcoco_test.json'),
        os.path.join(data_const.original_data_dir,
                     'data/instances_vcoco_all_2014.json'),
        os.path.join(data_const.original_data_dir,
                     'data/splits/vcoco_test.ids'))
    vcocoeval._do_eval(data_const, det_save_file, ovr_thresh=0.5)
0
def run_model(args, data_const):
    """Train the PGception pose branch on V-COCO on top of a frozen VS-GATs model.

    The pretrained VS-GATs graph network (``args.main_pretrained``) is loaded
    and frozen; only the PGception model is optimized.  Logits from both
    models are summed before the BCE-with-logits loss.

    Args:
        args: parsed command-line namespace (learning rate, batch size, model
            hyper-parameters, checkpoint/log directories, ...).
        data_const: dataset-constants object forwarded to ``VcocoDataset``.
    """
    # prepare data
    train_dataset = VcocoDataset(data_const=data_const, subset="vcoco_trainval", pg_only=False)
    val_dataset = VcocoDataset(data_const=data_const, subset="vcoco_val", pg_only=False)
    dataset = {'train': train_dataset, 'val': val_dataset}

    dataloader = {
        phase: DataLoader(dataset=dataset[phase], batch_size=args.batch_size,
                          shuffle=True, pin_memory=True, collate_fn=collate_fn)
        for phase in ('train', 'val')
    }
    print("Preparing data done!!!")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'training on {device}...')

    # load the VS-GATs checkpoint and freeze it: it only provides
    # complementary logits during training and is never updated
    checkpoint = torch.load(args.main_pretrained, map_location=device)
    print('vsgats Checkpoint loaded!')
    vs_gats = AGRNN(feat_type=checkpoint['feat_type'], bias=checkpoint['bias'],
                    bn=checkpoint['bn'], dropout=checkpoint['dropout'],
                    multi_attn=checkpoint['multi_head'], layer=checkpoint['layers'],
                    diff_edge=checkpoint['diff_edge'])
    vs_gats.load_state_dict(checkpoint['state_dict'])
    for param in vs_gats.parameters():
        param.requires_grad = False
    vs_gats.to(device)
    vs_gats.eval()

    # e.g. b_l=[64,64,128,128], o_c_l=[128,256,256,256]
    print(args.b_l, args.o_c_l)
    model = PGception(action_num=args.a_n, layers=args.n_layers, classifier_mod=args.c_m,
                      o_c_l=args.o_c_l, b_l=args.b_l, last_h_c=args.last_h_c,
                      bias=args.bias, drop=args.d_p, bn=args.bn,
                      agg_first=args.agg_first, attn=args.attn)
    # optionally resume from a pretrained PGception checkpoint
    if args.pretrained:
        print(f"loading pretrained model {args.pretrained}")
        checkpoints = torch.load(args.pretrained, map_location=device)
        model.load_state_dict(checkpoints['state_dict'])
    model.to(device)

    # build optimizer && criterion
    if args.optim == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0)
    elif args.optim == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0)
    else:
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0, amsgrad=True)
    criterion = nn.BCEWithLogitsLoss()
    # multiplies the lr by 0.1 every `args.scheduler_step` epochs
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.scheduler_step, gamma=0.1)

    # set visualization and create folder to save checkpoints
    writer = SummaryWriter(log_dir=args.log_dir + '/' + args.exp_ver)
    io.mkdir_if_not_exists(os.path.join(args.save_dir, args.exp_ver), recursive=True)

    # optionally initialize from a HICO-DET model, dropping its dataset-specific
    # final classifier layer
    if args.hico:
        print(f"loading pretrained model of HICO_DET dataset {args.hico}")
        hico_checkpoints = torch.load(args.hico, map_location=device)
        hico_dict = hico_checkpoints['state_dict']
        model_dict = model.state_dict()
        need_dict = {k: v for k, v in hico_dict.items()
                     if k not in ['classifier.4.weight', 'classifier.4.bias']}
        model_dict.update(need_dict)
        model.load_state_dict(model_dict)
        # fine-tune: the freshly initialized last layer gets its own small lr
        if args.fine_turn:
            last_layer_params = list(map(id, model.classifier[4].parameters()))
            base_params = filter(lambda p: id(p) not in last_layer_params, model.parameters())
            optimizer = optim.Adam([{'params': base_params, 'lr': args.lr},
                                    {'params': model.classifier[4].parameters(), 'lr': 3e-5}],
                                   weight_decay=0)
    print(optimizer)

    def _forward_batch(data):
        """Run both models on one batch; return (summed logits, edge labels)."""
        roi_labels = data['roi_labels']
        node_num = data['node_num']
        features = data['features'].to(device)
        spatial_feat = data['spatial_feat'].to(device)
        word2vec = data['word2vec'].to(device)
        edge_labels = data['edge_labels'].to(device)
        pose_normalized = data["pose_to_human"].to(device)
        pose_to_obj_offset = data["pose_to_obj_offset"].to(device)

        # NOTE(review): validation=True presumably makes the frozen AGRNN
        # return logits only -- confirm against the AGRNN implementation
        if 4 in args.b_l:
            # part-body graph variant: PGception returns two logit tensors
            outputs1, outputs2 = model(pose_normalized, pose_to_obj_offset)
            outputs = outputs1 + outputs2 + vs_gats(
                node_num, features, spatial_feat, word2vec, roi_labels, validation=True)
        else:
            outputs = vs_gats(node_num, features, spatial_feat, word2vec,
                              roi_labels, validation=True) \
                      + model(pose_normalized, pose_to_obj_offset)
        return outputs, edge_labels

    # start training
    for epoch in range(args.start_epoch, args.epoch):
        epoch_loss = 0
        # NOTE(review): the validation phase is disabled; only 'train' runs
        for phase in ['train']:
            start_time = time.time()
            running_loss = 0
            for data in tqdm(dataloader[phase]):
                if phase == "train":
                    model.train()
                    model.zero_grad()
                    outputs, edge_labels = _forward_batch(data)
                    loss = criterion(outputs, edge_labels)
                    loss.backward()
                    optimizer.step()
                else:
                    model.eval()
                    with torch.no_grad():
                        outputs, edge_labels = _forward_batch(data)
                        loss = criterion(outputs, edge_labels)

                # accumulate the sum of per-sample losses for the epoch average
                running_loss += loss.item() * edge_labels.shape[0]

            epoch_loss = running_loss / len(dataset[phase])
            writer.add_scalars('trainval_loss_epoch', {'train': epoch_loss}, epoch)
            # print progress on the first epoch, then on every epoch whose
            # index mod print_every equals 9
            # NOTE(review): `== 9` looks tuned for print_every == 10 -- confirm
            if epoch == 0 or (epoch % args.print_every) == 9:
                end_time = time.time()
                print("[{}] Epoch: {}/{} Loss: {} Execution time: {}".format(
                    phase, epoch + 1, args.epoch, epoch_loss, (end_time - start_time)))
        if args.scheduler_step:
            scheduler.step()
        # save a checkpoint every `save_every` epochs, but only from epoch 300 on
        if epoch % args.save_every == (args.save_every - 1) and epoch >= (300 - 1):
            checkpoint = {
                            'lr': args.lr,
                           'b_s': args.batch_size,
                          'bias': args.bias,
                            'bn': args.bn,
                       'dropout': args.d_p,
                         'o_c_l': args.o_c_l,
                           'b_l': args.b_l,
                      'last_h_c': args.last_h_c,
                           'a_n': args.a_n,
                'classifier_mod': args.c_m,
                      'n_layers': args.n_layers,
                     'agg_first': args.agg_first,
                          'attn': args.attn,
                    'state_dict': model.state_dict()
            }
            save_name = "checkpoint_" + str(epoch + 1) + '_epoch.pth'
            torch.save(checkpoint, os.path.join(args.save_dir, args.exp_ver, save_name))

    writer.close()
    print('Finishing training!')
Example #6
0
def main(args):
    """Run V-COCO test-set inference with a pretrained AGRNN model and evaluate.

    Loads the checkpoint named by ``args.pretrained``, runs inference over the
    V-COCO test split (unless a result file already exists and ``args.rewrite``
    is false), writes per human-object-pair detection results to a pickle
    file, and finally scores them with the official VCOCOeval tool.
    """
    # use GPU if available (and requested) else revert to CPU
    device = torch.device(
        'cuda' if torch.cuda.is_available() and args.gpu else 'cpu')
    print("Testing on", device)

    # Load checkpoint and set up model
    try:
        checkpoint = torch.load(args.pretrained, map_location=device)
        print('Checkpoint loaded!')

        # derive an experiment version tag from the checkpoint path when
        # none was given on the command line
        if not args.exp_ver:
            args.exp_ver = args.pretrained.split(
                "/")[-3] + "_" + args.pretrained.split("/")[-1].split("_")[-2]
        data_const = VcocoConstants(feat_type=checkpoint['feat_type'],
                                    exp_ver=args.exp_ver)
        # rebuild the model with the hyper-parameters stored in the checkpoint
        model = AGRNN(feat_type=checkpoint['feat_type'],
                      bias=checkpoint['bias'],
                      bn=checkpoint['bn'],
                      dropout=checkpoint['dropout'],
                      multi_attn=checkpoint['multi_head'],
                      layer=checkpoint['layers'],
                      diff_edge=checkpoint['diff_edge'])
        model.load_state_dict(checkpoint['state_dict'])
        model.to(device)
        model.eval()
        print('Constructed model successfully!')
    except Exception as e:
        print('Failed to load checkpoint or construct model!', e)
        sys.exit(1)

    io.mkdir_if_not_exists(data_const.result_dir)
    det_save_file = os.path.join(data_const.result_dir,
                                 'detection_results.pkl')
    # skip inference when results already exist, unless --rewrite was passed
    if not os.path.isfile(det_save_file) or args.rewrite:
        test_dataset = VcocoDataset(data_const=data_const, subset='vcoco_test')
        test_dataloader = DataLoader(dataset=test_dataset,
                                     batch_size=1,
                                     shuffle=False,
                                     collate_fn=collate_fn)
        # one result dict per (human, object) pair across all test images
        det_data_list = []
        for data in tqdm(test_dataloader):
            global_id = data['global_id'][0]
            det_boxes = data['det_boxes'][0]
            roi_scores = data['roi_scores'][0]
            roi_labels = data['roi_labels'][0]
            node_num = data['node_num']
            features = data['features'].to(device)
            spatial_feat = data['spatial_feat'].to(device)
            word2vec = data['word2vec'].to(device)

            # inference; the model expects a *list* of per-image roi_labels
            outputs, attn, attn_lang = model(
                node_num, features, spatial_feat, word2vec,
                [roi_labels])  # !NOTE: it is important to set [roi_labels]

            action_scores = torch.sigmoid(outputs).cpu().detach().numpy()
            attn = attn.cpu().detach().numpy()
            attn_lang = attn_lang.cpu().detach().numpy()

            # label 1 marks human nodes; pair each human with every other node
            h_idxs = np.where(roi_labels == 1)[0]
            for h_idx in h_idxs:
                for i_idx in range(node_num[0]):
                    if i_idx == h_idx:
                        continue
                    # save hoi results for this pair of the current image
                    single_result = {}
                    single_result['image_id'] = global_id
                    single_result['person_box'] = det_boxes[h_idx, :]
                    # map the (human, object) pair to its row in the flattened
                    # edge-score matrix (each node has node_num-1 outgoing edges)
                    if h_idx > i_idx:
                        edge_idx = h_idx * (node_num[0] - 1) + i_idx
                    else:
                        edge_idx = h_idx * (node_num[0] - 1) + i_idx - 1
                    # final pair score: detector confidences times action scores
                    score = roi_scores[h_idx] * roi_scores[
                        i_idx] * action_scores[edge_idx]
                    for action in vcoco_metadata.action_class_with_object:
                        if action == 'none':
                            continue
                        action_idx = vcoco_metadata.action_with_obj_index[
                            action]
                        single_action_score = score[action_idx]
                        if action in ('cut_with', 'eat_with', 'hit_with'):
                            # these verbs act through an instrument role
                            action = action.split('_')[0]
                            role_name = 'instr'
                        else:
                            role_name = vcoco_metadata.action_roles[action][1]
                        action_role_key = '{}_{}'.format(action, role_name)
                        # store [x1, y1, x2, y2, score] for this action-role
                        single_result[action_role_key] = np.append(
                            det_boxes[i_idx, :], single_action_score)

                    det_data_list.append(single_result)
        # save all detected results, closing the file handle properly
        with open(det_save_file, 'wb') as f:
            pickle.dump(det_data_list, f)
    # evaluate with the official V-COCO evaluation code
    vcocoeval = VCOCOeval(
        os.path.join(data_const.original_data_dir,
                     'data/vcoco/vcoco_test.json'),
        os.path.join(data_const.original_data_dir,
                     'data/instances_vcoco_all_2014.json'),
        os.path.join(data_const.original_data_dir,
                     'data/splits/vcoco_test.ids'))
    vcocoeval._do_eval(data_const, det_save_file, ovr_thresh=0.5)
Example #7
0
def run_model(config):
    """One Ray Tune trial: train/validate PGception on V-COCO pose features.

    Args:
        config: hyper-parameter sample from the Tune search space (branch
            hidden sizes ``b0_h``..``b3_h``, classifier width ``last_h_c``,
            dropout ``d_p``).

    Relies on the module-level ``args`` and ``data_const`` globals; reports
    train/val loss to Ray Tune after each epoch.
    """
    # prepare data
    global args
    global data_const
    train_dataset = VcocoDataset(data_const=data_const, subset="vcoco_train")
    val_dataset = VcocoDataset(data_const=data_const, subset="vcoco_val")
    dataset = {'train': train_dataset, 'val': val_dataset}

    dataloader = {
        phase: DataLoader(dataset=dataset[phase],
                          batch_size=args.batch_size,
                          shuffle=True,
                          pin_memory=True,
                          collate_fn=collate_fn)
        for phase in ('train', 'val')
    }
    print("Preparing data done!!!")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'training on {device}...')

    # model hyper-parameters come partly from the fixed CLI args and partly
    # from the Tune-sampled config
    model = PGception(
        action_num=args.a_n,
        classifier_mod=args.c_m,
        o_c_l=[config["b0_h"], config["b1_h"], config["b2_h"], config["b3_h"]],
        last_h_c=config["last_h_c"],
        bias=args.bias,
        drop=config['d_p'],
        bn=args.bn)
    # optionally warm-start from a pretrained checkpoint
    if args.pretrained:
        print(f"loading pretrained model {args.pretrained}")
        checkpoints = torch.load(args.pretrained, map_location=device)
        model.load_state_dict(checkpoints['state_dict'])
    model.to(device)
    # build optimizer && criterion
    if args.optim == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=0)
    else:
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0)
    criterion = nn.BCEWithLogitsLoss()
    # NOTE(review): this would multiply the lr by 1/3 every 400 epochs, but
    # scheduler.step() is never called, so it currently has no effect
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=400, gamma=1 / 3)
    # TensorBoard writer; per-epoch metrics are reported via Tune instead
    writer = SummaryWriter(log_dir=args.log_dir + '/' + args.exp_ver)

    # start training
    for epoch in range(args.start_epoch, args.epoch):
        # each epoch has a training and a validation step
        epoch_loss = 0
        for phase in ['train', 'val']:
            running_loss = 0
            for data in tqdm(dataloader[phase]):
                pose_feat = data["pose_feat"].to(device)
                labels = data['pose_labels'].to(device)
                if phase == "train":
                    model.train()
                    model.zero_grad()
                    outputs = model(pose_feat)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()
                else:
                    model.eval()
                    with torch.no_grad():
                        outputs = model(pose_feat)
                        loss = criterion(outputs, labels)

                # accumulate the sum of per-sample losses for the epoch average
                running_loss += loss.item() * labels.shape[0]

            epoch_loss = running_loss / len(dataset[phase])
            if phase == 'train':
                train_loss = epoch_loss
            else:
                # report both losses to Ray Tune after the validation pass
                tune.track.log(train_loss=train_loss, val_loss=epoch_loss)

    writer.close()
    print('Finishing training!')