Ejemplo n.º 1
0
def create_model(args, dbinfo):
    """ Builds the network: a bare ``nn.Module`` holding two sub-modules.

    The container gets a graph network under ``ecc`` and a PointNet under
    ``ptn`` (registered in that order), its parameter count and structure are
    printed, and it is moved to the GPU when requested.

    Args:
        args: configuration object; the ``model_config``, ``ptn_*``,
            ``fnet_*``, ``edge_mem_limit`` and ``cuda`` attributes are read.
        dbinfo: mapping providing the ``'edge_feats'`` and ``'node_feats'``
            entries.

    Returns:
        The assembled module (on the GPU when ``args.cuda`` is truthy).
    """
    net = nn.Module()

    # Last entry of the second PointNet width list; it is handed to the graph
    # network as its second argument (presumably the node embedding size --
    # confirm against graphnet.GraphNetwork).
    ecc_nfeat = args.ptn_widths[1][-1]

    net.ecc = graphnet.GraphNetwork(
        args.model_config,
        ecc_nfeat,
        [dbinfo['edge_feats']] + args.fnet_widths,
        args.fnet_orthoinit,
        args.fnet_llbias,
        args.fnet_bnidx,
        args.edge_mem_limit,
    )

    ptn_widths = args.ptn_widths
    ptn_widths_stn = args.ptn_widths_stn
    net.ptn = pointnet.PointNet(
        ptn_widths[0],
        ptn_widths[1],
        ptn_widths_stn[0],
        ptn_widths_stn[1],
        dbinfo['node_feats'],
        args.ptn_nfeat_stn,
        prelast_do=args.ptn_prelast_do,
    )

    param_count = sum(p.numel() for p in net.parameters())
    print('Total number of parameters: {}'.format(param_count))
    print(net)

    if args.cuda:
        net.cuda()
    return net
Ejemplo n.º 2
0
def resume(args, dbinfo):
    """ Loads model and optimizer state from a previous checkpoint.

    Args:
        args: run-time arguments. ``args.resume`` is the checkpoint path and
            ``args.model_config`` overrides the value stored in the
            checkpoint. On return ``args.start_epoch`` holds the epoch stored
            in the checkpoint.
        dbinfo: mapping providing at least ``'edge_feats'`` plus whatever
            ``create_model`` reads from it.

    Returns:
        Tuple ``(model, optimizer, stats)``. ``stats`` is the parsed JSON
        content of ``trainlog.txt`` located next to the checkpoint, or ``[]``
        when that file cannot be read or parsed.
    """
    print("=> loading checkpoint '{}'".format(args.resume))
    # NOTE: torch.load unpickles the file, so only resume from trusted
    # checkpoints.
    checkpoint = torch.load(args.resume)

    # Keep the checkpoint's arguments under their own name: the caller's
    # ``args`` must stay reachable so that ``start_epoch`` is written to the
    # object the caller holds and the train log is looked up next to the file
    # that was actually loaded.
    ckpt_args = checkpoint['args']
    ckpt_args.model_config = args.model_config  # to ensure compatibility with previous arguments convention
    # this should be removed once new models are uploaded

    # use original arguments, architecture can't change
    model = create_model(ckpt_args, dbinfo)
    optimizer = create_optimizer(ckpt_args, model)

    print('edge feature', dbinfo['edge_feats'])
    # NOTE(review): the graph network is rebuilt with a hard-coded width of 7
    # *after* the optimizer was created, so the optimizer still references the
    # parameters of the discarded module. Confirm this is intended before
    # training with the returned optimizer.
    model.ecc = graphnet.GraphNetwork(
        ckpt_args.model_config, 7,
        [dbinfo['edge_feats']] + ckpt_args.fnet_widths,
        ckpt_args.fnet_orthoinit, ckpt_args.fnet_llbias, ckpt_args.fnet_bnidx,
        ckpt_args.edge_mem_limit)
    if ckpt_args.cuda:
        # create_model() moved the model to the GPU before ``ecc`` was
        # replaced; move the freshly built sub-module as well.
        model.cuda()

    state = checkpoint['state_dict']
    # NOTE(review): these two entries are overwritten with random values of
    # hard-coded shape; verify they match the rebuilt ``ecc`` module.
    state['ecc.1.bias'] = torch.tensor(np.random.rand(7, 1))
    state['ecc.1.weight'] = torch.tensor(np.random.rand(8, 352))

    model.load_state_dict(state)

    if 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    for group in optimizer.param_groups:
        group['initial_lr'] = ckpt_args.lr

    args.start_epoch = checkpoint['epoch']

    trainlog_path = os.path.join(os.path.dirname(args.resume), 'trainlog.txt')
    try:
        with open(trainlog_path) as log_file:
            stats = json.loads(log_file.read())
    except (OSError, ValueError):
        # Best effort: a missing, unreadable or malformed log (JSON decode
        # errors are ValueError subclasses) simply starts with empty stats.
        stats = []
    return model, optimizer, stats