def create_model(args, dbinfo):
    """Build the graph-network + PointNet model described by ``args``.

    An empty ``nn.Module`` is populated with two submodules, registered in
    this order (which fixes parameter / state_dict ordering and the order in
    which their weights are initialised):

    * ``ecc`` -- a ``graphnet.GraphNetwork``;
    * ``ptn`` -- a ``pointnet.PointNet``.

    Args:
        args: options namespace. Attributes read: ``model_config``,
            ``ptn_widths``, ``ptn_widths_stn``, ``ptn_nfeat_stn``,
            ``ptn_prelast_do``, ``fnet_widths``, ``fnet_orthoinit``,
            ``fnet_llbias``, ``fnet_bnidx``, ``edge_mem_limit``, ``cuda``.
        dbinfo: mapping providing at least ``'edge_feats'`` and
            ``'node_feats'``.

    Returns:
        The assembled model, moved to the GPU when ``args.cuda`` is truthy.
        The total parameter count and the module tree are printed as a side
        effect.
    """
    net = nn.Module()

    # The graph network's feature width is the last entry of the PointNet's
    # second width list (presumably the PointNet output size -- confirm).
    graph_in_width = args.ptn_widths[1][-1]
    # The edge-filter network's first layer is sized by the edge feature count.
    filter_widths = [dbinfo['edge_feats']] + args.fnet_widths

    net.ecc = graphnet.GraphNetwork(
        args.model_config,
        graph_in_width,
        filter_widths,
        args.fnet_orthoinit,
        args.fnet_llbias,
        args.fnet_bnidx,
        args.edge_mem_limit,
    )
    net.ptn = pointnet.PointNet(
        args.ptn_widths[0],
        args.ptn_widths[1],
        args.ptn_widths_stn[0],
        args.ptn_widths_stn[1],
        dbinfo['node_feats'],
        args.ptn_nfeat_stn,
        prelast_do=args.ptn_prelast_do,
    )

    n_params = sum(param.numel() for param in net.parameters())
    print('Total number of parameters: {}'.format(n_params))
    print(net)

    if args.cuda:
        net.cuda()
    return net
def resume(args, dbinfo):
    """Load model and optimizer state from a previous checkpoint.

    The model architecture and the optimizer are built from the options
    stored in the checkpoint (``checkpoint['args']``). Before that,
    ``model_config`` in those stored options is overwritten with the
    caller's value.

    The caller's ``args`` is updated in place: ``args.start_epoch`` is set to
    the epoch recorded in the checkpoint.

    Args:
        args: options of the current run. Reads ``resume`` (checkpoint path)
            and ``model_config``; receives ``start_epoch``.
        dbinfo: mapping providing at least ``'edge_feats'`` (and whatever
            ``create_model`` reads from it).

    Returns:
        ``(model, optimizer, stats)``. ``stats`` is the JSON content of
        ``trainlog.txt`` located next to the checkpoint file, or ``[]`` when
        that file is missing, unreadable, or not valid JSON.
    """
    print("=> loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(args.resume)

    # Keep the checkpoint's options separate from the caller's `args`.
    # Rebinding `args` here would make the `start_epoch` assignment below
    # land on a discarded object. It would also make the trainlog lookup use
    # the stale `resume` path stored in the checkpoint.
    ckpt_args = checkpoint['args']
    # To ensure compatibility with the previous arguments convention.
    # This should be removed once new models are uploaded.
    ckpt_args.model_config = args.model_config

    # Use the original arguments: the architecture can't change.
    model = create_model(ckpt_args, dbinfo)
    # NOTE(review): the optimizer (and 'initial_lr' below) use the
    # checkpoint's options, not the current run's -- confirm this is intended.
    optimizer = create_optimizer(ckpt_args, model)

    print('edge feature', dbinfo['edge_feats'])
    # NOTE(review): experimental override of the graph network -- confirm
    # before relying on it:
    #  * the input width is hard-coded to 7, whereas create_model derives it
    #    from ptn_widths[1][-1];
    #  * the optimizer above was created from the previous `model.ecc`. If it
    #    captured the parameter tensors at construction (torch.optim
    #    optimizers do), it will not update this replacement module;
    #  * 'ecc.1.*' checkpoint weights are replaced by uniform [0, 1) float64
    #    noise. The bias shape (7, 1) does not match the weight's 8 output
    #    rows, and a strict load_state_dict raises on any size mismatch
    #    with the rebuilt module.
    model.ecc = graphnet.GraphNetwork(
        ckpt_args.model_config, 7,
        [dbinfo['edge_feats']] + ckpt_args.fnet_widths,
        ckpt_args.fnet_orthoinit, ckpt_args.fnet_llbias,
        ckpt_args.fnet_bnidx, ckpt_args.edge_mem_limit)
    state = checkpoint['state_dict']
    state['ecc.1.bias'] = torch.tensor(np.random.rand(7, 1))
    state['ecc.1.weight'] = torch.tensor(np.random.rand(8, 352))
    model.load_state_dict(state)

    if 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    for group in optimizer.param_groups:
        group['initial_lr'] = ckpt_args.lr

    # Reported back to the caller through its own options object.
    args.start_epoch = checkpoint['epoch']

    # The training log lives next to the checkpoint being resumed.
    trainlog_path = os.path.join(os.path.dirname(args.resume), 'trainlog.txt')
    try:
        with open(trainlog_path) as log_file:
            stats = json.load(log_file)
    except (OSError, ValueError):
        # Missing/unreadable file or malformed JSON (JSONDecodeError and
        # UnicodeDecodeError are ValueError subclasses): start fresh.
        stats = []
    return model, optimizer, stats