def main():
    """Training entry point: shard the data across loaders, fan work out to
    worker processes over pipes, train, then tear everything down."""
    args = parse_args()
    # 'spawn' keeps CUDA state from leaking into forked children.
    mp.set_start_method('spawn')
    _logger = log.get_logger(__name__, args)
    _logger.info(print_args(args))

    # One DataLoader per worker, each reading its own shard (part_no/part_total)
    # of the same shuffled file list.
    file_list = os.listdir(args.train_file)
    random.shuffle(file_list)
    loaders = []
    for i in range(args.worker):
        loader = data_loader.DataLoader(
            args.train_file,
            args.dict_file,
            separate_conj_stmt=args.direction,
            binary=args.binary,
            part_no=i,
            part_total=args.worker,
            file_list=file_list,
            norename=args.norename,
            filter_abelian=args.fabelian,
            compatible=args.compatible)
        loaders.append(loader)
        loader.start_reader()

    # Touch the GPU once up front so a broken CUDA setup fails fast,
    # before any worker processes are spawned.
    torch.randn(10).cuda()

    net, mid_net, loss_fn = create_models(args, loaders[0], allow_resume=True)
    # Wrap the real modules with FakeModule stand-ins for the parent process.
    net = FakeModule(net)
    if mid_net is not None:
        mid_net = FakeModule(mid_net)
    for i in range(len(loss_fn)):
        loss_fn[i] = FakeModule(loss_fn[i])
    opt = get_opt(net, mid_net, loss_fn, args)

    # One pair of one-way pipes per worker: parent sends work in via
    # inqueues[i], reads results back via outqueues[i].
    inqueues = []
    outqueues = []
    plist = []
    for i in range(args.worker):
        recv_p, send_p = Pipe(False)
        recv_p2, send_p2 = Pipe(False)
        inqueues.append(send_p)
        outqueues.append(recv_p2)
        plist.append(
            Process(target=worker, args=(recv_p, send_p2, loaders[i], args, i)))
        plist[-1].start()

    _logger.warning('Training begins')
    train(inqueues, outqueues, net, mid_net, loss_fn, opt, loaders, args, _logger)

    # Terminate the workers first, then destruct every loader exactly once.
    # (The original also called destruct() on the loop-leftover `loader`
    # here, which destructed the last loader twice.)
    for p in plist:
        p.terminate()
    for loader in loaders:
        loader.destruct()
    _logger.warning('Training ends')
def main():
    """Evaluation entry point: load a checkpoint, broadcast its weights to
    worker processes, and aggregate their per-shard accuracy counts."""
    parser = argparse.ArgumentParser(description='The ultimate tester')
    parser.add_argument('--model', type=str, help='The model file name used for testing', required=True)
    parser.add_argument('--model_path', type=str, help='The path to model folder', default='../models')
    parser.add_argument('--log', type=str, help='Path to log file', required=True)
    parser.add_argument('--data', type=str, help='Path to testing set folder', default='../data/hol_data/test')
    parser.add_argument('--worker', type=int, help='Number of workers', default=4)
    parser.add_argument('--max_pair', type=int, help='Change max_pair settings')
    parser.add_argument('--compatible', action='store_true', help='Use compatible mode to run.')
    parser.add_argument('--dict_file', type=str, help='Replace dict')
    settings = parser.parse_args()

    # 'spawn' keeps CUDA state from leaking into forked children.
    mp.set_start_method('spawn')
    _logger = log.get_logger(__name__, settings)
    _logger.info('Test program parameters')
    _logger.info(print_args(settings))

    model_path = os.path.join(settings.model_path, settings.model)
    net, mid_net, loss_fn, test_loader, args = load_model(
        model_path, settings.data, settings.compatible)

    # Command-line settings override whatever was stored in the checkpoint.
    args.test_file = settings.data
    if settings.dict_file is not None:
        args.dict_file = settings.dict_file
    if settings.max_pair is not None:
        args.max_pair = settings.max_pair
    args.compatible = settings.compatible
    if settings.worker is not None:
        args.worker = settings.worker

    # One pair of one-way pipes per worker process.
    inqueues = []
    outqueues = []
    plist = []
    for i in range(args.worker):
        recv_p, send_p = Pipe(False)
        recv_p2, send_p2 = Pipe(False)
        inqueues.append(send_p)
        outqueues.append(recv_p2)
        plist.append(
            Process(target=worker, args=(recv_p, send_p2, test_loader, args, i)))
        plist[-1].start()
    # Drop the parent's reference; the worker processes own the loader now.
    test_loader = None

    _logger.info('Model parameters')
    _logger.info(print_args(args))

    valid_start = time.time()
    # Ship the model weights to every worker; data['test'] = True switches
    # the workers into evaluation mode.
    data = {}
    data['args'] = args
    data['net'] = net.state_dict()
    data['fix_net'] = False
    if mid_net is not None:
        data['mid_net'] = mid_net.state_dict()
    data['loss_fn'] = []
    for loss in loss_fn:
        data['loss_fn'].append(loss.state_dict())
    data['test'] = True
    for i in range(args.worker):
        inqueues[i].send(data)

    result_correct = 0
    result_total = 0
    _logger.warning('Test start!')
    for i in range(args.worker):
        data = outqueues[i].recv()
        result_correct += data['correct']
        result_total += data['total']
    # Guard against an empty test set: report 0 accuracy instead of
    # crashing with ZeroDivisionError.
    result_ = result_correct / result_total if result_total else 0.0
    _logger.warning('Validation complete! Time lapse: %.3f, Test acc: %.5f' %
                    (time.time() - valid_start, result_))
    for p in plist:
        p.terminate()
    # Give terminated workers a moment to release their resources.
    time.sleep(5)
def create_models(args, loader, allow_resume=False):
    """Build the graph network, optional middle net, and per-step loss modules.

    Args:
        args: parsed command-line namespace of hyper-parameters and flags.
        loader: a DataLoader; only its ``dict_size`` is read here.
        allow_resume: when True and ``args.resume`` is set, restore weights
            from the checkpoint at ``args.resume``.

    Returns:
        ``(net, mid_net, loss_fn)`` — ``mid_net`` is None unless the loss is
        'mixmax'/'mixmean'; ``loss_fn`` is a list of ``args.loss_step`` modules.

    Raises:
        ValueError: if ``args.loss`` is not a recognized option.
    """
    net = model.GraphNet(
        loader.dict_size, args.nFeats, args.nSteps, args.block,
        args.module_depth, args.bias, args.short_cut, args.direction,
        args.loss, args.binary,
        no_step_supervision=args.no_step_supervision,
        tied_weight=args.tied_weight,
        compatible=args.compatible).cuda()
    mid_net = None
    if args.loss in ('mixmax', 'mixmean'):
        mid_net = model.FullyConnectedNet(
            args.nFeats, args.nFeats // 2, bias=args.bias).cuda()

    # One loss module per supervised step; the module type and its input
    # width depend on the --loss option.
    loss_fn = []
    for i in range(args.loss_step):
        if args.loss == 'mulloss':
            loss_fn.append(
                loss.MultiplyLoss(args.nFeats,
                                  cond_short_cut=args.cond_short_cut).cuda())
        elif args.loss == 'condloss':
            loss_fn.append(
                loss.CondLoss(args.nFeats * 2, args.nFeats,
                              layer_list=args.loss_layers,
                              dropout=args.dropout,
                              cond_short_cut=args.cond_short_cut).cuda())
        elif args.loss in ('concat', 'concat_em_uc'):
            # NOTE(review): bias=args.compatible passes a compatibility flag
            # as the bias switch — looks suspicious; confirm it is intended.
            if args.uncondition or args.add_conj:
                loss_fn.append(
                    loss.ClassifyLoss(args.nFeats, args.nFeats // 2,
                                      layer_list=args.loss_layers,
                                      dropout=args.dropout,
                                      bias=args.compatible).cuda())
            else:
                loss_fn.append(
                    loss.ClassifyLoss(args.nFeats * 2, args.nFeats,
                                      layer_list=args.loss_layers,
                                      dropout=args.dropout,
                                      bias=args.compatible).cuda())
        elif args.loss in ('mixmax', 'mixmean'):
            loss_fn.append(
                loss.UCSimLoss(args.nFeats, args.nFeats // 2,
                               layer_list=args.loss_layers,
                               dropout=args.dropout).cuda())
        elif args.loss == 'pair':
            loss_fn.append(
                loss.ClassifyLoss(args.nFeats // 2, args.nFeats // 4,
                                  layer_list=args.loss_layers,
                                  dropout=args.dropout).cuda())
        elif args.loss == 'em':
            loss_fn.append(
                loss.ClassifyLoss(args.nFeats, args.nFeats // 2,
                                  layer_list=args.loss_layers,
                                  dropout=args.dropout).cuda())
        else:
            # Was `assert False, ...`, which is silently stripped under
            # `python -O`; an explicit raise always fires.
            raise ValueError('Wrong --loss option!')

    # Optionally restore weights from a previous run's checkpoint.
    if args.resume is not None and allow_resume:
        data = torch.load(args.resume)
        _logger = log.get_logger('Create Model', args)
        _logger.warning('Load Model!')
        _logger.info('Previous training info:')
        _logger.info('Epoch: %d Current iter: %d Total iter: %d',
                     data['aux']['epoch'], data['aux']['cur_iter'],
                     data['aux']['total_iter'])
        _logger.warning('Previous training args:')
        _logger.info(print_args(data['args']))
        net.load_state_dict(data['net']['state_dict'])
        if mid_net is not None:
            mid_net.load_state_dict(data['mid_net']['state_dict'])
        # --resume_only_net skips restoring the loss-module weights.
        if not args.resume_only_net:
            for i in range(len(loss_fn)):
                loss_fn[i].load_state_dict(data['loss_fn'][i]['state_dict'])
    return net, mid_net, loss_fn
# Checkpoint inspector: prints the training progress counters and the
# argument namespace stored in a saved model checkpoint.
# Usage: python <this script> <checkpoint_path>
import sys
import argparse  # NOTE(review): imported but unused here — possibly leftover.
import torch
from model_utils import print_args

# torch.load deserializes the checkpoint dict saved during training.
data = torch.load(sys.argv[1])
print('Previous training info:')
print(
    'Epoch: %d (start from 0) Current iter: %d Total iter: %d' %
    (data['aux']['epoch'], data['aux']['cur_iter'], data['aux']['total_iter']))
# print_args returns a formatted string of the stored argument namespace.
print(print_args(data['args']))