# print(dictionaries[0])

print('Initializing CLEVR dataset...')

# Build the model.
n_words = len(dictionaries[0]) + 1
n_choices = len(dictionaries[1])
print('n_words = {}, n_choices = {}'.format(n_words, n_choices))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Stack_NMN(args.max_stack_len, args.max_time_stamps, args.n_modules, n_choices,
                  args.n_nodes, n_words, args.embed_size, args.lstm_hid_dim,
                  args.input_feat_dim, args.map_dim, args.mlp_hid_dim, args.mem_dim,
                  args.kb_dim, args.kernel_size, device).to(device)

# Load the checkpoint, then wrap the model for multi-GPU evaluation.
model.load_state_dict(torch.load(args.ckpt))
model = nn.DataParallel(model)

clevr_test = ClevrDataset(args.clevr_dir, split='test', features_dir=args.features_dir,
                          dictionaries=dictionaries)
test_set = DataLoader(clevr_test, batch_size=args.batch_size, num_workers=1)
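# The fragment above stops after building the test loader. Below is a minimal
# evaluation-loop sketch, not the repository's own code. It assumes each batch
# unpacks as (image, question, q_len, answer), mirroring the collate convention
# of the training scripts here, and that the model's forward signature is
# model(image, question, q_len); both are assumptions.
def evaluate(model, test_set, device):
    model.eval()
    n_correct, n_total = 0, 0
    with torch.no_grad():
        for image, question, q_len, answer in test_set:
            image = image.to(device)
            question = question.to(device)
            answer = answer.to(device)
            scores = model(image, question, q_len)  # assumed forward signature
            n_correct += (scores.argmax(dim=1) == answer).sum().item()
            n_total += answer.size(0)
    print('test accuracy = {:.4f}'.format(n_correct / n_total))

# evaluate(model, test_set, device)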
def main():
    parser = argparse.ArgumentParser(description='Stack-NMN')
    parser.add_argument('--embed_size', type=int, help='embedding dim. of question words', default=300)
    parser.add_argument('--lstm_hid_dim', type=int, help='hidden dim. of LSTM', default=256)
    parser.add_argument('--input_feat_dim', type=int, help='feat. dim. of image features', default=1024)
    parser.add_argument('--map_dim', type=int, help='hidden dim. of intermediate attention maps', default=512)
    parser.add_argument('--text_param_dim', type=int, help='hidden dim. of textual param.', default=512)
    parser.add_argument('--mlp_hid_dim', type=int, help='hidden dim. of MLP', default=512)
    parser.add_argument('--mem_dim', type=int, help='hidden dim. of memory', default=512)
    parser.add_argument('--kb_dim', type=int, help='hidden dim. of conv features', default=512)
    parser.add_argument('--max_stack_len', type=int, help='max. length of stack', default=8)
    parser.add_argument('--max_time_stamps', type=int, help='max. number of time-stamps for modules', default=9)
    parser.add_argument('--clevr_dir', type=str, help='directory of CLEVR dataset', required=True)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--n_epochs', type=int, default=50)
    parser.add_argument('--n_modules', type=int, default=9)
    parser.add_argument('--kernel_size', type=int, default=3)
    parser.add_argument('--model_dir', type=str, required=True)
    parser.add_argument('--features_dir', type=str, default='data')
    parser.add_argument('--clevr_feature_dir', type=str, default='/u/username/data/clevr_features/')
    parser.add_argument('--copy_data', action='store_true')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--reg_coeff', type=float, default=1e-1)
    parser.add_argument('--ckpt', type=str, default='')
    parser.add_argument('--resume', action='store_true')  # use only on SLURM
    parser.add_argument('--optim', type=str, default='adam')
    parser.add_argument('--use_half', action='store_true')  # use only on SLURM
    # SGDR hyper-parameters
    parser.add_argument('--T0', type=int, default=1)
    parser.add_argument('--Tmult', type=int, default=2)
    parser.add_argument('--eta_min', type=float, default=1e-5)
    args = parser.parse_args()
    print(args)

    '''
    with open('data/dic.pkl', 'rb') as f1:
        dic = pickle.load(f1)
    n_words = len(dic['word_dic']) + 1
    n_choices = len(dic['answer_dic'])
    print('n_words = {}, n_choices = {}'.format(n_words, n_choices))
    '''

    print('Building word dictionaries from all the words in the dataset...')
    dictionaries = utils_clevr_humans.build_dictionaries(args.clevr_dir)
    print('Building word dictionaries completed!')

    print('Initializing CLEVR dataset...')

    # Build the model.
    n_words = len(dictionaries[0]) + 1
    n_choices = len(dictionaries[1])
    print('n_words = {}, n_choices = {}'.format(n_words, n_choices))

    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)

    writer = SummaryWriter(log_dir=args.model_dir)

    if args.copy_data:
        # Stage the pre-extracted CLEVR features on the node-local disk.
        start_time = time.time()
        copytree(args.clevr_feature_dir,
                 os.path.join(os.path.expandvars('$SLURM_TMPDIR'), 'clevr_features/'))
        # copytree('/u/username/data/clevr_features/', '/Tmp/username/clevr_features/')
        # args.features_dir = '/Tmp/username/clevr_features/'
        args.features_dir = os.path.join(os.path.expandvars('$SLURM_TMPDIR'), 'clevr_features/')
        print('data copy finished in {} sec.'.format(time.time() - start_time))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # print('device = {}'.format(device))

    # NOTE: this script hard-codes n_nodes to 3 in the Stack_NMN call.
    model = Stack_NMN(args.max_stack_len, args.max_time_stamps, args.n_modules, n_choices,
                      3, n_words, args.embed_size, args.lstm_hid_dim, args.input_feat_dim,
                      args.map_dim, args.mlp_hid_dim, args.mem_dim, args.kb_dim,
                      args.kernel_size, False, device).to(device)
    # model = nn.DataParallel(model)

    criterion = nn.CrossEntropyLoss()
    # optimizer_1 = optim.Adam(map(lambda p: p[1], filter(lambda p: p[1].requires_grad and 'weight_mlp' not in p[0], model.named_parameters())), lr=args.lr, weight_decay=0e-3)
    # optimizer_2 = optim.Adam(map(lambda p: p[1], filter(lambda p: p[1].requires_grad and 'weight_mlp' in p[0], model.named_parameters())), lr=0e-8, weight_decay=0e-2)

    # All optimizer choices act on the trainable parameters only; 'sgdr' uses
    # plain SGD, with the warm-restart schedule handled separately.
    trainable_params = [p for _, p in model.named_parameters() if p.requires_grad]
    optim_classes = {
        'adam': optim.Adam,
        'asgd': optim.ASGD,
        'sgd': optim.SGD,
        'adamax': optim.Adamax,
        'adagrad': optim.Adagrad,
        'adadelta': optim.Adadelta,
        'sgdr': optim.SGD,
    }
    optimizer = optim_classes[args.optim](trainable_params, lr=args.lr)

    clevr_val = ClevrDataset(args.clevr_dir, split='val', features_dir=args.features_dir,
                             dictionaries=dictionaries)
    val_set = DataLoader(clevr_val, batch_size=args.batch_size, num_workers=4,
                         collate_fn=collate_data)

    start_epoch = 0
    if len(args.ckpt) > 0:
        # Load all weights except the word embedding, whose vocabulary size may
        # differ, then copy the old embedding rows into the new (larger) table.
        ckpt_state = torch.load(args.ckpt)
        model.load_state_dict({k: v for k, v in ckpt_state.items() if 'embed' not in k},
                              strict=False)
        # start_epoch = int(args.ckpt.split('_')[-1].split('.')[0])
        # print('start_epoch = {}'.format(start_epoch))
        prev_embed = ckpt_state['embed.weight']
        model.embed.weight.data[:prev_embed.size(0), :].copy_(prev_embed)
        # print(model.embed.weight.data)

    if args.resume:
        # Resume from the newest 'ckpt_epoch_<N>.model' file in model_dir.
        model_ckpts = list(filter(lambda x: 'ckpt_epoch' in x, os.listdir(args.model_dir)))
        if len(model_ckpts) > 0:
            model_ckpts_epoch_ids = [int(filename.split('_')[-1].split('.')[0])
                                     for filename in model_ckpts]
            start_epoch = max(model_ckpts_epoch_ids)
            latest_ckpt_file = os.path.join(args.model_dir,
                                            'ckpt_epoch_{}.model'.format(start_epoch))
            model.load_state_dict(torch.load(latest_ckpt_file))
            print('Loaded ckpt file {}'.format(latest_ckpt_file))
            print('start_epoch = {}'.format(start_epoch))

    val(model, criterion, optimizer, val_set, args.batch_size, device, writer, args.n_epochs)

    writer.close()
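# The parser above defines SGDR hyper-parameters (--T0, --Tmult, --eta_min) that
# the code shown never consumes. A minimal sketch of one plausible wiring, an
# assumption rather than the repository's own logic: attach PyTorch's built-in
# cosine-annealing warm-restart schedule when args.optim == 'sgdr'.
def make_sgdr_scheduler(optimizer, args):
    # Call scheduler.step() once per epoch in the training loop to anneal the
    # learning rate down to eta_min and restart every T0, T0*Tmult, ... epochs.
    return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=args.T0, T_mult=args.Tmult, eta_min=args.eta_min)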
def main():
    parser = argparse.ArgumentParser(description='Stack-NMN')
    parser.add_argument('--embed_size', type=int, help='embedding dim. of question words', default=300)
    parser.add_argument('--lstm_hid_dim', type=int, help='hidden dim. of LSTM', default=256)
    parser.add_argument('--input_feat_dim', type=int, help='feat. dim. of image features', default=1024)
    parser.add_argument('--map_dim', type=int, help='hidden dim. of intermediate attention maps', default=512)
    # parser.add_argument('--text_param_dim', type=int, help='hidden dim. of textual param.', default=512)
    parser.add_argument('--mlp_hid_dim', type=int, help='hidden dim. of MLP', default=512)
    parser.add_argument('--mem_dim', type=int, help='hidden dim. of memory', default=512)
    parser.add_argument('--kb_dim', type=int, help='hidden dim. of conv features', default=512)
    parser.add_argument('--max_stack_len', type=int, help='max. length of stack', default=8)
    parser.add_argument('--max_time_stamps', type=int, help='max. number of time-stamps for modules', default=9)
    parser.add_argument('--clevr_dir', type=str, help='directory of CLEVR dataset', required=True)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--n_epochs', type=int, default=50)
    parser.add_argument('--n_modules', type=int, default=7)  # includes 1 NoOp module
    parser.add_argument('--n_nodes', type=int, default=2)  # TODO: change/tune later
    parser.add_argument('--kernel_size', type=int, default=3)
    parser.add_argument('--model_dir', type=str, required=True)
    parser.add_argument('--features_dir', type=str, default='data')
    parser.add_argument('--clevr_feature_dir', type=str, default='/u/username/data/clevr_features/')
    parser.add_argument('--copy_data', action='store_true')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--temperature', type=float, default=0.2)
    parser.add_argument('--reg_coeff', type=float, default=1e-2)
    parser.add_argument('--ckpt', type=str, default='')
    parser.add_argument('--resume', action='store_true')  # use only on SLURM
    parser.add_argument('--reg_coeff_op_loss', type=float, default=1e-1)
    # DARTS args
    parser.add_argument('--unrolled', action='store_true', default=False, help='use one-step unrolled validation loss')
    parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
    parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    args = parser.parse_args()

    print('Building word dictionaries from all the words in the dataset...')
    dictionaries = utils.build_dictionaries(args.clevr_dir)
    print('Building word dictionaries completed!')

    print('Initializing CLEVR dataset...')

    # Build the model.
    n_words = len(dictionaries[0]) + 1
    n_choices = len(dictionaries[1])
    print('n_words = {}, n_choices = {}'.format(n_words, n_choices))

    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)

    writer = SummaryWriter(log_dir=args.model_dir)

    if args.copy_data:
        # Stage the pre-extracted CLEVR features on the node-local disk.
        start_time = time.time()
        copytree(args.clevr_feature_dir,
                 os.path.join(os.path.expandvars('$SLURM_TMPDIR'), 'clevr_features/'))
        # copytree('/u/username/data/clevr_features/', '/Tmp/username/clevr_features/')
        # args.features_dir = '/Tmp/username/clevr_features/'
        args.features_dir = os.path.join(os.path.expandvars('$SLURM_TMPDIR'), 'clevr_features/')
        print('data copy finished in {} sec.'.format(time.time() - start_time))

    # TODO: remove this later
    # args.features_dir = '/Tmp/username/clevr_features/'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # print('device = {}'.format(device))

    model = Stack_NMN(args.max_stack_len, args.max_time_stamps, args.n_modules, n_choices,
                      args.n_nodes, args.temperature, n_words, args.embed_size,
                      args.lstm_hid_dim, args.input_feat_dim, args.map_dim,
                      args.mlp_hid_dim, args.mem_dim, args.kb_dim, args.kernel_size,
                      device).to(device)
    # model = Stack_NMN(args.max_stack_len, args.max_time_stamps, args.n_modules, n_choices, n_words, args.embed_size, args.lstm_hid_dim, args.input_feat_dim, args.map_dim, args.text_param_dim, args.mlp_hid_dim, args.kernel_size, device).to(device)

    start_epoch = 0
    if len(args.ckpt) > 0:
        model.load_state_dict(torch.load(args.ckpt))
        # Recover the epoch index from a filename of the form 'ckpt_epoch_<N>.model'.
        start_epoch = int(args.ckpt.split('_')[-1].split('.')[0])
        print('start_epoch = {}'.format(start_epoch))

    if args.resume:
        # Resume from the newest 'ckpt_epoch_<N>.model' file in model_dir.
        model_ckpts = list(filter(lambda x: 'ckpt_epoch' in x, os.listdir(args.model_dir)))
        if len(model_ckpts) > 0:
            model_ckpts_epoch_ids = [int(filename.split('_')[-1].split('.')[0])
                                     for filename in model_ckpts]
            start_epoch = max(model_ckpts_epoch_ids)
            latest_ckpt_file = os.path.join(args.model_dir,
                                            'ckpt_epoch_{}.model'.format(start_epoch))
            model.load_state_dict(torch.load(latest_ckpt_file))
            print('Loaded ckpt file {}'.format(latest_ckpt_file))
            print('start_epoch = {}'.format(start_epoch))

    model = nn.DataParallel(model)

    criterion = nn.CrossEntropyLoss()
    # optimizer_1 = optim.Adam(map(lambda p: p[1], filter(lambda p: p[1].requires_grad and 'weight_mlp' not in p[0], model.named_parameters())), lr=args.lr, weight_decay=0e-3)
    # optimizer_2 = optim.Adam(map(lambda p: p[1], filter(lambda p: p[1].requires_grad and 'weight_mlp' in p[0], model.named_parameters())), lr=0e-8, weight_decay=0e-2)

    # Only the network weights are optimized here; the architecture weights are
    # handled by the Architect below.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.module.network_parameters()),
                           lr=args.lr)
    # optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.network_parameters()), lr=args.lr)
    # optimizer = optim.Adam(map(lambda x: x[1], filter(lambda p: p[1].requires_grad, model.named_parameters())), lr=args.lr)

    architect = Architect(model, device, args)

    clevr_train = ClevrDataset(args.clevr_dir, split='train', features_dir=args.features_dir,
                               dictionaries=dictionaries)
    clevr_val = ClevrDataset(args.clevr_dir, split='val', features_dir=args.features_dir,
                             dictionaries=dictionaries)

    train_set = DataLoader(clevr_train, batch_size=args.batch_size, num_workers=0,
                           collate_fn=collate_data)
    val_set = DataLoader(clevr_val, batch_size=args.batch_size, num_workers=0,
                         collate_fn=collate_data)
    # A separately shuffled view of the validation set feeds the architecture updates.
    val_set_architect = DataLoader(clevr_val, batch_size=args.batch_size, num_workers=0,
                                   collate_fn=collate_data,
                                   sampler=torch.utils.data.RandomSampler(
                                       list(range(len(clevr_val)))))

    for epoch_id in range(start_epoch, args.n_epochs):
        train(epoch_id, model, architect, criterion, optimizer, train_set, val_set_architect,
              args.batch_size, device, writer, args.n_epochs, args.lr, args.unrolled,
              args.reg_coeff)
        valid(epoch_id, model, criterion, optimizer, val_set, args.batch_size, device, writer,
              args.n_epochs)
        # Checkpoint after every epoch (1-indexed to match the resume logic above).
        with open('{}/ckpt_epoch_{}.model'.format(args.model_dir, str(epoch_id + 1)), 'wb') as f1:
            torch.save(model.module.state_dict(), f1)

    clevr_train.close()
    clevr_val.close()

    writer.close()
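# train() above delegates the architecture update to Architect, whose step()
# signature is not shown in this file. For reference, a minimal first-order
# sketch of the DARTS-style update it performs is given below. This is an
# illustration only: it assumes the wrapped model exposes arch_parameters()
# (the softmax weights over module layouts) and takes (image, question, q_len)
# in its forward; neither is confirmed by the code above.
def make_arch_optimizer(model, args):
    # Architecture parameters get their own Adam, as in the original DARTS paper.
    return optim.Adam(model.module.arch_parameters(),  # assumed accessor
                      lr=args.arch_learning_rate, betas=(0.5, 0.999),
                      weight_decay=args.arch_weight_decay)

def arch_step(model, criterion, arch_optimizer, image, question, q_len, answer):
    # One update on a held-out validation batch: minimize the validation loss
    # with respect to the architecture parameters only (first-order DARTS).
    arch_optimizer.zero_grad()
    loss = criterion(model(image, question, q_len), answer)
    loss.backward()
    arch_optimizer.step()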