# Fix the random seeds so evaluation runs are reproducible.
np.random.seed(1234)
tf.set_random_seed(1234)

# Restore the arguments used at training time (saved next to the
# checkpoint) and overlay the current test-time arguments on top.
param_path = args['checkpoint'].replace('.tmodel', '_params.json')
with open(param_path, 'r') as file_id:
    saved_args = json.load(file_id)
saved_args.update(args)
args = saved_args

# NOTE(review): the argument parser defines 'preload_features'; confirm
# that 'preload_feats' is the key the data loader actually reads.
args['preload_feats'] = False
# No attention supervision is needed when evaluating.
args['supervise_attention'] = False
# Cap the batch size so complex models still fit in memory.
args['batch_size'] = min(args['batch_size'], 10)
support.pretty_print_dict(args)

# Resolve the imdb data file for the requested split; the refer module
# variant uses a differently named file.
root = args['data_root']
file_name = ('imdb/imdb_%s_refermrf_att.npy' % args['test_split']
             if args['use_refer']
             else 'imdb/imdb_%s.npy' % args['test_split'])
imdb_path_val = os.path.join(root, file_name)

# One assembler each for question and caption programs, sharing the
# same program vocabulary.
question_assembler = Assembler(args['prog_vocab_path'])
caption_assembler = Assembler(args['prog_vocab_path'])
assemblers = {'ques': question_assembler, 'cap': caption_assembler}
def read_command_line(): title = 'Train explicit coreference resolution visual dialog model' parser = argparse.ArgumentParser(description=title) #------------------------------------------------------------------------- # data input settings parser.add_argument('--dataset', default='visdial_v0.9_tiny', help='Visdial dataset type') parser.add_argument('--data_root', default='data/', help='Root to the data') parser.add_argument('--feature_path', default='data/resnet_res5c/', help='Path to the image features') parser.add_argument('--text_vocab_path', default='', help='Path to the vocabulary for text') parser.add_argument('--prog_vocab_path', default='', help='Path to the vocabulary for programs') parser.add_argument('--snapshot_path', default='checkpoints/', help='Path to save checkpoints') #-------------------------------------------------------------------------- # specify encoder/decoder parser.add_argument('--model', default='nmn-cap-prog-only', help='Name of the model, will be changed later') parser.add_argument('--generator', default='ques', help='Name of the generator to use (ques | memory)') parser.add_argument('--decoder', default='gen', help='Name of the decoder to use (gen | disc)') parser.add_argument('--preload_features', default=False, type=bool, help='Preload visual features on RAM') #------------------------------------------------------------------------- # model hyperparameters parser.add_argument('--h_feat', default=14, type=int, help='Height of visual conv feature') parser.add_argument('--w_feat', default=14, type=int, help='Width of visual conv feature') parser.add_argument('--d_feat', default=2048, type=int, help='Size of visual conv feature') parser.add_argument('--text_embed_size', default=300, type=int, help='Size of embedding for text') parser.add_argument('--map_size', default=1024, type=int, help='Size of the final mapping') parser.add_argument('--prog_embed_size', default=300, type=int, help='Size of embedding for program tokens') 
parser.add_argument('--lstm_size', default=1000, type=int, help='Size of hidden state in LSTM') parser.add_argument('--enc_dropout', default=True, type=bool, help='Dropout in encoder') parser.add_argument('--dec_dropout', default=True, type=bool, help='Dropout in decoder') parser.add_argument('--num_layers', default=2, type=int, help='Number of layers in LSTM') parser.add_argument( '--max_enc_len', default=24, type=int, help='Maximum encoding length for sentences (ques|cap)') parser.add_argument('--max_dec_len', default=14, type=int, help='Maximum decoding length for programs (ques|cap)') parser.add_argument('--dec_sampling', default=False, type=bool, help='Sample while decoding programs vs argmax') #--------------------------------------------------------------------------- parser.add_argument('--use_refer', dest='use_refer', action='store_true', help='Flag to use Refer for coreference resolution') parser.set_defaults(use_refer=False) parser.add_argument('--use_fact', dest='use_fact', action='store_true', help='Flag to use the fact in coreference pool') parser.set_defaults(use_fact=False) parser.add_argument('--supervise_attention', dest='supervise_attention', action='store_true', help='Flag to supervise attention for the modules') parser.set_defaults(supervise_attention=False) parser.add_argument('--amalgam_text_feats', dest='amalgam_text_feats', action='store_true', help='Flag to amalgamate text features') parser.set_defaults(amalgam_text_feats=False) parser.add_argument('--no_cap_alignment', dest='cap_alignment', action='store_false', help='Use the auxiliary caption alignment loss') parser.set_defaults(cap_alignment=True) #------------------------------------------------------------------------- # optimization params parser.add_argument( '--batch_size', default=20, type=int, help='Training batch size (adjust based on GPU memory)') parser.add_argument('--learning_rate', default=1e-3, type=float, help='Learning rate for training') parser.add_argument('--dropout', 
default=0.5, type=float, help='Dropout') parser.add_argument('--num_epochs', default=20, type=int, help='Maximum number of epochs to run training') parser.add_argument('--gpu_id', type=int, default=0, help='GPU id to use for training, -1 for CPU') #------------------------------------------------------------------------- try: parsed_args = vars(parser.parse_args()) except (IOError) as msg: parser.error(str(msg)) # set the cuda environment variable for the gpu to use gpu_id = '' if parsed_args['gpu_id'] < 0 else str(parsed_args['gpu_id']) print('Using GPU id: %s' % gpu_id) os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id # pretty print arguments and return support.pretty_print_dict(parsed_args) return parsed_args