import copy
import json
import logging
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Project-local helpers assumed importable from this repo:
# XNMNet, ClevrDataLoader, VQADataLoader, todevice, validate,
# validate_with_david_generated_program, save_checkpoint.


def train(args):
    logging.info("Create train_loader and val_loader.........")
    train_loader_kwargs = {
        'question_pt': args.train_question_pt,
        'scene_pt': args.train_scene_pt,
        'vocab_json': args.vocab_json,
        'batch_size': args.batch_size,
        'ratio': args.ratio,
        'shuffle': True
    }
    val_loader_kwargs = {
        'question_pt': args.val_question_pt,
        'scene_pt': args.val_scene_pt,
        'vocab_json': args.vocab_json,
        'batch_size': args.batch_size,
        'shuffle': False
    }
    train_loader = ClevrDataLoader(**train_loader_kwargs)
    val_loader = ClevrDataLoader(**val_loader_kwargs)

    logging.info("Create model.........")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_kwargs = {k: v for k, v in vars(args).items() if k in {
        'dim_v', 'dim_pre_v', 'num_edge_cat', 'num_class', 'num_attribute',
    }}
    model_kwargs_tosave = copy.deepcopy(model_kwargs)
    model_kwargs['vocab'] = train_loader.vocab
    model = XNMNet(**model_kwargs).to(device)
    logging.info(model)

    optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.l2reg)
    # Decay the lr by 10x at one milestone; with a smaller training ratio the
    # milestone (in epochs) comes proportionally later.
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer=optimizer, milestones=[int(1 / args.ratio)], gamma=0.1)
    criterion = nn.CrossEntropyLoss().to(device)

    logging.info("Start training........")
    tic = time.time()
    iter_count = 0
    for epoch in range(args.num_epoch):
        for i, batch in enumerate(train_loader.generator()):
            iter_count += 1
            progress = epoch + i / len(train_loader)
            answers, questions, *batch_input = [todevice(x, device) for x in batch]
            logits, others = model(*batch_input)
            loss = criterion(logits, answers)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % (len(train_loader) // 10) == 0:
                logging.info("Progress %.3f  loss = %.3f" % (progress, loss.item()))
        scheduler.step()
        # Validate and checkpoint after every epoch.
        valid_acc = validate(model, val_loader, device)
        logging.info('\n ~~~~~~ Valid Accuracy: %.4f ~~~~~~~\n' % valid_acc)
        save_checkpoint(epoch, model, optimizer, model_kwargs_tosave,
                        os.path.join(args.save_dir, 'model.pt'))
        logging.info(' >>>>>> save to %s <<<<<<' % args.save_dir)
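# A hypothetical argument parser for the CLEVR trainer above -- a minimal
# sketch only. Flag names mirror the attributes train() reads; every default
# value is an illustrative assumption, not the repo's configuration.
def parse_clevr_args():
    import argparse
    parser = argparse.ArgumentParser()
    for flag in ('--train_question_pt', '--train_scene_pt', '--val_question_pt',
                 '--val_scene_pt', '--vocab_json', '--save_dir'):
        parser.add_argument(flag, required=True)
    parser.add_argument('--batch_size', type=int, default=128)   # assumed default
    parser.add_argument('--ratio', type=float, default=1.0)      # fraction of training data used
    parser.add_argument('--lr', type=float, default=1e-3)        # assumed default
    parser.add_argument('--l2reg', type=float, default=0.0)      # assumed default
    parser.add_argument('--num_epoch', type=int, default=20)     # assumed default
    # model hyperparameters forwarded into XNMNet through vars(args)
    for flag, default in (('--dim_v', 512), ('--dim_pre_v', 512),
                          ('--num_edge_cat', 9), ('--num_class', 28),
                          ('--num_attribute', 4)):               # all assumed
        parser.add_argument(flag, type=int, default=default)
    return parser.parse_args()
# Usage sketch: train(parse_clevr_args())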
# VQA variant of the training loop. train_loader_kwargs, val_loader_kwargs,
# model_kwargs, lr, and num_epoch are assumed to be defined at module scope,
# built from command-line arguments as in the other trainers.
def train():
    train_loader = VQADataLoader(**train_loader_kwargs)
    val_loader = VQADataLoader(**val_loader_kwargs)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_kwargs.update({'vocab': train_loader.vocab, 'device': device})
    model = XNMNet(**model_kwargs).to(device)

    # Initialize the word embeddings with pretrained GloVe vectors.
    train_loader.glove_matrix = torch.FloatTensor(train_loader.glove_matrix).to(device)
    with torch.no_grad():
        model.token_embedding.weight.set_(train_loader.glove_matrix)
    ################################################################
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(parameters, lr, weight_decay=0)

    for epoch in range(num_epoch):
        model.train()
        for i, batch in enumerate(tqdm(train_loader, total=len(train_loader))):
            progress = epoch + i / len(train_loader)
            coco_ids, answers, *batch_input = [todevice(x, device) for x in batch]
            logits, others = model(*batch_input)
            ##################### loss #####################
            # Soft cross-entropy against the 10 annotator answers per question.
            nll = -nn.functional.log_softmax(logits, dim=1)
            loss = (nll * answers / 10).sum(dim=1).mean()
            #################################################
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_value_(parameters, clip_value=0.5)
            optimizer.step()
            if (i + 1) % (len(train_loader) // 50) == 0:
                logging.info("Progress %.3f  ce_loss = %.3f" % (progress, loss.item()))

        train_acc, train_loss = validate(model, train_loader, device,
                                         withLossFlag=True, func=nn.functional)
        logging.info('\n ~~~~~~ Epoch: %d ~~~~~~~\n' % epoch)
        logging.info('\n ~~~~~~ Train Accuracy: %.4f ~~~~~~~\n' % train_acc)
        logging.info('\n ~~~~~~ Train Loss: %.4f ~~~~~~~\n' % train_loss)
        valid_acc, valid_loss = validate(model, val_loader, device,
                                         withLossFlag=True, func=nn.functional)
        logging.info('\n ~~~~~~ Valid Accuracy: %.4f ~~~~~~~\n' % valid_acc)
        logging.info('\n ~~~~~~ Valid Loss: %.4f ~~~~~~~\n' % valid_loss)
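# Why "answers / 10": each VQA question comes with 10 human answers, and
# `answers` is assumed to be a (batch, num_answers) tensor of per-class
# annotator counts, so answers / 10 is a soft target vector. A self-contained
# illustration with made-up values:
def _soft_ce_demo():
    counts = torch.tensor([[7., 3., 0.],
                           [10., 0., 0.]])          # 2 questions, 3 candidate answers
    logits = torch.tensor([[2.0, 1.0, -1.0],
                           [3.0, -2.0, -2.0]])
    nll = -nn.functional.log_softmax(logits, dim=1)
    return (nll * counts / 10).sum(dim=1).mean()    # same form as the loss above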
# Validation entry point for CLEVR: restore a checkpoint and evaluate, either
# with ground-truth programs or with externally predicted ('david') programs.
args.val_scene_pt = os.path.join(args.input_dir, args.val_scene_pt)
device = 'cuda'
val_loader_kwargs = {
    'question_pt': args.val_question_pt,
    'scene_pt': args.val_scene_pt,
    'vocab_json': args.vocab_json,
    'batch_size': 128,
    'shuffle': False
}
val_loader = ClevrDataLoader(**val_loader_kwargs)

loaded = torch.load(args.ckpt, map_location={'cuda:0': 'cpu'})
model_kwargs = loaded['model_kwargs']
model_kwargs.update({'vocab': val_loader.vocab})
model = XNMNet(**model_kwargs).to(device)
model.load_state_dict(loaded['state_dict'])

num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(' ~~~~~~~~~~ num parameters: %d ~~~~~~~~~~~~~' % num_parameters)

if args.program == 'gt':
    print('validate with **ground truth** program')
    val_acc, val_details = validate(model, val_loader, device, detail=True)
elif args.program == 'david':
    print('validate with **david predicted** program')
    val_acc, val_details = validate_with_david_generated_program(
        model, val_loader, device, args.pretrained_dir)
print("Validate acc: %.4f" % val_acc)
print(json.dumps(val_details, indent=2))
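# save_checkpoint itself is not shown in these snippets. A minimal sketch
# consistent with the keys the loaders above read ('epoch', 'state_dict',
# 'optimizer', 'model_kwargs') -- an assumption about its shape, not the
# repo's verbatim code:
def save_checkpoint(epoch, model, optimizer, model_kwargs, path):
    torch.save({
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'model_kwargs': model_kwargs,  # enough to re-instantiate XNMNet; vocab is re-attached at load time
    }, path)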
# Validation / test entry point for VQA.
device = 'cuda'
loaded = torch.load(args.ckpt, map_location={'cuda:0': 'cpu'})
model_kwargs = loaded['model_kwargs']

if args.mode == 'val':
    val_loader_kwargs = {
        'question_pt': args.val_question_pt,
        'vocab_json': args.vocab_json,
        'feature_h5': args.val_feature_h5,
        'batch_size': 128,
        'spatial': model_kwargs['spatial'],
        'num_workers': 2,
        'shuffle': False
    }
    val_loader = VQADataLoader(**val_loader_kwargs)
    model_kwargs.update({'vocab': val_loader.vocab, 'device': device})
    model = XNMNet(**model_kwargs).to(device)
    model.load_state_dict(loaded['state_dict'])
    valid_acc = validate(model, val_loader, device)
    print('valid acc: %.4f' % valid_acc)
elif args.mode == 'test':
    assert args.output_file and os.path.exists(args.test_question_json)
    test_loader_kwargs = {
        'question_pt': args.test_question_pt,
        'vocab_json': args.vocab_json,
        'feature_h5': args.test_feature_h5,
        'batch_size': 128,
        'spatial': model_kwargs['spatial'],
        'num_workers': 2,
        'shuffle': False
    }
    test_loader = VQADataLoader(**test_loader_kwargs)
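# todevice is likewise assumed rather than shown: a plausible sketch that
# recursively moves tensors to the target device and passes everything else
# (e.g. lists of coco ids) through unchanged.
def todevice(x, device):
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, (list, tuple)):
        return type(x)(todevice(v, device) for v in x)
    return x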
def train(args):
    logging.info("Create train_loader and val_loader.........")
    train_loader_kwargs = {
        'question_pt': args.train_question_pt,
        'vocab_json': args.vocab_json,
        'feature_h5': args.feature_h5,
        'batch_size': args.batch_size,
        'spatial': args.spatial,
        'num_workers': 2,
        'shuffle': True
    }
    train_loader = VQADataLoader(**train_loader_kwargs)
    if args.val:
        val_loader_kwargs = {
            'question_pt': args.val_question_pt,
            'vocab_json': args.vocab_json,
            'feature_h5': args.feature_h5,
            'batch_size': args.batch_size,
            'spatial': args.spatial,
            'num_workers': 2,
            'shuffle': False
        }
        val_loader = VQADataLoader(**val_loader_kwargs)

    logging.info("Create model.........")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_kwargs = {
        'vocab': train_loader.vocab,
        'dim_v': args.dim_v,
        'dim_word': args.dim_word,
        'dim_hidden': args.dim_hidden,
        'dim_vision': args.dim_vision,
        'dim_edge': args.dim_edge,
        'cls_fc_dim': args.cls_fc_dim,
        'dropout_prob': args.dropout,
        'T_ctrl': args.T_ctrl,
        'glimpses': args.glimpses,
        'stack_len': args.stack_len,
        'device': device,
        'spatial': args.spatial,
        'use_gumbel': args.module_prob_use_gumbel == 1,
        'use_validity': args.module_prob_use_validity == 1,
    }
    model_kwargs_tosave = {k: v for k, v in model_kwargs.items() if k != 'vocab'}
    model = XNMNet(**model_kwargs).to(device)
    logging.info(model)

    logging.info('load glove vectors')
    train_loader.glove_matrix = torch.FloatTensor(train_loader.glove_matrix).to(device)
    model.token_embedding.weight.data.set_(train_loader.glove_matrix)
    ################################################################
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(parameters, args.lr, weight_decay=0)

    start_epoch = 0
    if args.restore:
        print("Restore checkpoint and optimizer...")
        ckpt = os.path.join(args.save_dir, 'model.pt')
        ckpt = torch.load(ckpt, map_location={'cuda:0': 'cpu'})
        start_epoch = ckpt['epoch'] + 1
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
    # Halves the lr every args.lr_halflife scheduler steps (stepped once per batch).
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.5 ** (1 / args.lr_halflife))

    logging.info("Start training........")
    for epoch in range(start_epoch, args.num_epoch):
        model.train()
        for i, batch in enumerate(train_loader):
            progress = epoch + i / len(train_loader)
            coco_ids, answers, *batch_input = [todevice(x, device) for x in batch]
            logits, others = model(*batch_input)
            ##################### loss #####################
            nll = -nn.functional.log_softmax(logits, dim=1)
            loss = (nll * answers / 10).sum(dim=1).mean()
            #################################################
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_value_(parameters, clip_value=0.5)
            optimizer.step()
            scheduler.step()  # moved after optimizer.step(), as PyTorch >= 1.1 expects
            if (i + 1) % (len(train_loader) // 50) == 0:
                logging.info("Progress %.3f  ce_loss = %.3f" % (progress, loss.item()))
        save_checkpoint(epoch, model, optimizer, model_kwargs_tosave,
                        os.path.join(args.save_dir, 'model.pt'))
        logging.info(' >>>>>> save to %s <<<<<<' % args.save_dir)
        if args.val:
            valid_acc = validate(model, val_loader, device)
            logging.info('\n ~~~~~~ Valid Accuracy: %.4f ~~~~~~~\n' % valid_acc)
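# Sanity check on the decay rate used above: gamma = 0.5 ** (1 / lr_halflife)
# applied once per batch halves the learning rate every lr_halflife batches.
# Parameter values here are illustrative, not the repo's defaults.
def _halflife_demo(lr0=1e-3, lr_halflife=50000):
    gamma = 0.5 ** (1 / lr_halflife)
    return lr0 * gamma ** lr_halflife   # == lr0 / 2 (up to float rounding)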