Example #1
import copy
import logging
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim

# ClevrDataLoader, XNMNet, todevice, validate and save_checkpoint are
# project-local helpers assumed to be importable from the surrounding codebase.


def train(args):
    logging.info("Create train_loader and val_loader.........")
    train_loader_kwargs = {
        'question_pt': args.train_question_pt,
        'scene_pt': args.train_scene_pt,
        'vocab_json': args.vocab_json,
        'batch_size': args.batch_size,
        'ratio': args.ratio,
        'shuffle': True
    }
    val_loader_kwargs = {
        'question_pt': args.val_question_pt,
        'scene_pt': args.val_scene_pt,
        'vocab_json': args.vocab_json,
        'batch_size': args.batch_size,
        'shuffle': False
    }
    
    train_loader = ClevrDataLoader(**train_loader_kwargs)
    val_loader = ClevrDataLoader(**val_loader_kwargs)

    logging.info("Create model.........")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_kwargs = {
        k: v for k, v in vars(args).items()
        if k in {'dim_v', 'dim_pre_v', 'num_edge_cat', 'num_class', 'num_attribute'}
    }
    model_kwargs_tosave = copy.deepcopy(model_kwargs)
    model_kwargs['vocab'] = train_loader.vocab
    model = XNMNet(**model_kwargs).to(device)
    logging.info(model)

    optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.l2reg)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[int(1/args.ratio)], gamma=0.1)
    criterion = nn.CrossEntropyLoss().to(device)
    logging.info("Start training........")
    tic = time.time()
    iter_count = 0
    for epoch in range(args.num_epoch):
        for i, batch in enumerate(train_loader.generator()):
            iter_count += 1
            progress = epoch + i / len(train_loader)
            answers, questions, *batch_input = \
                    [todevice(x, device) for x in batch]

            logits, others = model(*batch_input)
            loss = criterion(logits, answers)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # max(..., 1) guards against a zero modulus on very small loaders
            if (i + 1) % max(len(train_loader) // 10, 1) == 0:
                logging.info("Progress %.3f  loss = %.3f" % (progress, loss.item()))
        scheduler.step()
        # validate and checkpoint after every epoch
        valid_acc = validate(model, val_loader, device)
        logging.info('\n ~~~~~~ Valid Accuracy: %.4f ~~~~~~~\n' % valid_acc)

        save_checkpoint(epoch, model, optimizer, model_kwargs_tosave,
                        os.path.join(args.save_dir, 'model.pt'))
        logging.info(' >>>>>> save to %s <<<<<<' % args.save_dir)
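
The training loop above ends by calling a project-local save_checkpoint helper. A minimal sketch, assuming it persists exactly the keys that the restore paths in Examples 3-5 read back ('epoch', 'state_dict', 'optimizer', 'model_kwargs'), might look like this; the project's actual implementation may differ:

import torch

def save_checkpoint(epoch, model, optimizer, model_kwargs, path):
    # Store everything needed to rebuild the model and resume training.
    state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'model_kwargs': model_kwargs,
    }
    torch.save(state, path)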
Example #2
def train():
    # Relies on module-level globals: train_loader_kwargs, val_loader_kwargs,
    # model_kwargs, lr and num_epoch must be defined before train() is called.
    train_loader = VQADataLoader(**train_loader_kwargs)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_kwargs.update({'vocab': train_loader.vocab, 'device': device})
    val_loader = VQADataLoader(**val_loader_kwargs)
    model = XNMNet(**model_kwargs).to(device)
    train_loader.glove_matrix = torch.FloatTensor(train_loader.glove_matrix).to(device)
    with torch.no_grad():
        model.token_embedding.weight.set_(train_loader.glove_matrix)
    ################################################################
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(parameters, lr, weight_decay=0)
    for epoch in range(num_epoch):
        model.train()
        i = 0
        for batch in tqdm(train_loader, total=len(train_loader)):
            progress = epoch + i / len(train_loader)
            coco_ids, answers, *batch_input = [todevice(x, device) for x in batch]
            logits, others = model(*batch_input)
            ##################### loss #####################
            nll = -nn.functional.log_softmax(logits, dim=1)
            # answers carries per-candidate annotator counts; dividing by 10
            # turns them into soft targets for this cross-entropy-style loss
            loss = (nll * answers / 10).sum(dim=1).mean()
            #################################################
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_value_(parameters, clip_value=0.5)
            optimizer.step()
            if (i + 1) % max(len(train_loader) // 50, 1) == 0:
                logging.info("Progress %.3f  ce_loss = %.3f" % (progress, loss.item()))
            i += 1
        train_acc, train_loss = validate(model, train_loader, device,
                                         withLossFlag=True, func=nn.functional)
        logging.info('\n ~~~~~~ Epoch: %d ~~~~~~~\n' % epoch)
        logging.info('\n ~~~~~~ Train Accuracy: %.4f ~~~~~~~\n' % train_acc)
        logging.info('\n ~~~~~~ Train Loss: %.4f ~~~~~~~\n' % train_loss)
        valid_acc, valid_loss = validate(model, val_loader, device,
                                         withLossFlag=True, func=nn.functional)
        logging.info('\n ~~~~~~ Valid Accuracy: %.4f ~~~~~~~\n' % valid_acc)
        logging.info('\n ~~~~~~ Valid Loss: %.4f ~~~~~~~\n' % valid_loss)
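
Examples 1, 2 and 5 all move each batch element with a project-local todevice helper. A minimal sketch, assuming batches can contain tensors as well as nested lists or tuples of tensors, could be:

import torch

def todevice(x, device):
    # Recursively move a tensor, or any nested list/tuple of tensors, to device.
    if isinstance(x, (list, tuple)):
        return [todevice(e, device) for e in x]
    if isinstance(x, torch.Tensor):
        return x.to(device)
    return x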
Example #3
    args.val_scene_pt = os.path.join(args.input_dir, args.val_scene_pt)

    device = 'cuda'
    val_loader_kwargs = {
        'question_pt': args.val_question_pt,
        'scene_pt': args.val_scene_pt,
        'vocab_json': args.vocab_json,
        'batch_size': 128,
        'shuffle': False
    }
    val_loader = ClevrDataLoader(**val_loader_kwargs)

    loaded = torch.load(args.ckpt, map_location={'cuda:0': 'cpu'})
    model_kwargs = loaded['model_kwargs']
    model_kwargs.update({'vocab': val_loader.vocab})
    model = XNMNet(**model_kwargs).to(device)
    model.load_state_dict(loaded['state_dict'])

    num_parameters = sum(p.numel() for p in model.parameters()
                         if p.requires_grad)
    print(' ~~~~~~~~~~ num parameters: %d ~~~~~~~~~~~~~' % num_parameters)

    if args.program == 'gt':
        print('validate with **ground truth** program')
        val_acc, val_details = validate(model, val_loader, device, detail=True)
    elif args.program == 'david':
        print('validate with **david predicted** program')
        val_acc, val_details = validate_with_david_generated_program(
            model, val_loader, device, args.pretrained_dir)
    else:
        # without this branch, val_acc would be undefined below
        raise ValueError('unknown program source: %s' % args.program)
    print("Validate acc: %.4f" % val_acc)
    print(json.dumps(val_details, indent=2))
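
validate is another project-local helper. A minimal accuracy-only sketch for the CLEVR loaders used here, assuming the same batch layout as Example 1 (answers, questions, then model inputs) and ignoring the detail/withLossFlag variants the other examples pass, might be:

import torch

def validate(model, loader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in loader.generator():
            answers, questions, *batch_input = [todevice(x, device) for x in batch]
            logits, _ = model(*batch_input)
            correct += (logits.argmax(dim=1) == answers).sum().item()
            total += answers.size(0)
    return correct / total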
Example #4
device = 'cuda'
loaded = torch.load(args.ckpt, map_location={'cuda:0': 'cpu'})
model_kwargs = loaded['model_kwargs']
if args.mode == 'val':
    val_loader_kwargs = {
        'question_pt': args.val_question_pt,
        'vocab_json': args.vocab_json,
        'feature_h5': args.val_feature_h5,
        'batch_size': 128,
        'spatial': model_kwargs['spatial'],
        'num_workers': 2,
        'shuffle': False
    }
    val_loader = VQADataLoader(**val_loader_kwargs)
    model_kwargs.update({'vocab': val_loader.vocab, 'device': device})
    model = XNMNet(**model_kwargs).to(device)
    model.load_state_dict(loaded['state_dict'])
    valid_acc = validate(model, val_loader, device)
    print('valid acc: %.4f' % valid_acc)
elif args.mode == 'test':
    assert args.output_file and os.path.exists(args.test_question_json)
    test_loader_kwargs = {
        'question_pt': args.test_question_pt,
        'vocab_json': args.vocab_json,
        'feature_h5': args.test_feature_h5,
        'batch_size': 128,
        'spatial': model_kwargs['spatial'],
        'num_workers': 2,
        'shuffle': False
    }
    test_loader = VQADataLoader(**test_loader_kwargs)
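
Both evaluation examples pass map_location={'cuda:0': 'cpu'} to torch.load, which remaps tensors saved on GPU 0 onto the CPU so the checkpoint opens even on a machine without that GPU. A minimal standalone sketch ('model.pt' is a placeholder path):

import torch

loaded = torch.load('model.pt', map_location={'cuda:0': 'cpu'})
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Rebuild the model from the stored kwargs, then move it to whatever is available:
# model = XNMNet(**loaded['model_kwargs']).to(device)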
Example #5
def train(args):
    logging.info("Create train_loader and val_loader.........")
    train_loader_kwargs = {
        'question_pt': args.train_question_pt,
        'vocab_json': args.vocab_json,
        'feature_h5': args.feature_h5,
        'batch_size': args.batch_size,
        'spatial': args.spatial,
        'num_workers': 2,
        'shuffle': True
    }
    train_loader = VQADataLoader(**train_loader_kwargs)
    if args.val:
        val_loader_kwargs = {
            'question_pt': args.val_question_pt,
            'vocab_json': args.vocab_json,
            'feature_h5': args.feature_h5,
            'batch_size': args.batch_size,
            'spatial': args.spatial,
            'num_workers': 2,
            'shuffle': False
        }
        val_loader = VQADataLoader(**val_loader_kwargs)

    logging.info("Create model.........")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_kwargs = {
        'vocab': train_loader.vocab,
        'dim_v': args.dim_v,
        'dim_word': args.dim_word,
        'dim_hidden': args.dim_hidden,
        'dim_vision': args.dim_vision,
        'dim_edge': args.dim_edge,
        'cls_fc_dim': args.cls_fc_dim,
        'dropout_prob': args.dropout,
        'T_ctrl': args.T_ctrl,
        'glimpses': args.glimpses,
        'stack_len': args.stack_len,
        'device': device,
        'spatial': args.spatial,
        'use_gumbel': args.module_prob_use_gumbel == 1,
        'use_validity': args.module_prob_use_validity == 1,
    }
    model_kwargs_tosave = {
        k: v
        for k, v in model_kwargs.items() if k != 'vocab'
    }
    model = XNMNet(**model_kwargs).to(device)
    logging.info(model)
    logging.info('load glove vectors')
    train_loader.glove_matrix = torch.FloatTensor(
        train_loader.glove_matrix).to(device)
    with torch.no_grad():
        model.token_embedding.weight.set_(train_loader.glove_matrix)
    ################################################################

    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(parameters, args.lr, weight_decay=0)

    start_epoch = 0
    if args.restore:
        print("Restore checkpoint and optimizer...")
        ckpt = os.path.join(args.save_dir, 'model.pt')
        ckpt = torch.load(ckpt, map_location={'cuda:0': 'cpu'})
        start_epoch = ckpt['epoch'] + 1
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                 0.5**(1 / args.lr_halflife))

    logging.info("Start training........")
    for epoch in range(start_epoch, args.num_epoch):
        model.train()
        for i, batch in enumerate(train_loader):
            progress = epoch + i / len(train_loader)
            coco_ids, answers, *batch_input = [
                todevice(x, device) for x in batch
            ]
            logits, others = model(*batch_input)
            ##################### loss #####################
            nll = -nn.functional.log_softmax(logits, dim=1)
            loss = (nll * answers / 10).sum(dim=1).mean()
            #################################################
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_value_(parameters, clip_value=0.5)
            optimizer.step()
            scheduler.step()  # per-iteration decay; step after optimizer.step()
            if (i + 1) % max(len(train_loader) // 50, 1) == 0:
                logging.info("Progress %.3f  ce_loss = %.3f" %
                             (progress, loss.item()))
        save_checkpoint(epoch, model, optimizer, model_kwargs_tosave,
                        os.path.join(args.save_dir, 'model.pt'))
        logging.info(' >>>>>> save to %s <<<<<<' % args.save_dir)
        if args.val:
            valid_acc = validate(model, val_loader, device)
            logging.info('\n ~~~~~~ Valid Accuracy: %.4f ~~~~~~~\n' %
                         valid_acc)
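
The ExponentialLR factor 0.5 ** (1 / args.lr_halflife), stepped once per iteration, halves the learning rate every lr_halflife iterations, since gamma ** lr_halflife == 0.5. A quick self-contained check (the halflife value below is hypothetical):

import torch
from torch import nn, optim

lr_halflife = 1000  # hypothetical value for the check
opt = optim.Adam([nn.Parameter(torch.zeros(1))], lr=1e-3)
sched = optim.lr_scheduler.ExponentialLR(opt, 0.5 ** (1 / lr_halflife))
for _ in range(lr_halflife):
    opt.step()
    sched.step()
print(opt.param_groups[0]['lr'])  # ~0.0005: halved after lr_halflife steps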