def load(config, dict_name=None):
    """Build dataloaders, network, optimiser and head ordering for a run.

    Args:
        config: configuration namespace for the training run.
        dict_name: filename of a saved checkpoint inside ``config.out_dir``,
            used when a previous run is resumed (only honoured when
            ``config.restart`` is also set).

    Returns:
        Tuple of ``(dataloaders_head_A, dataloaders_head_B,
        mapping_assignment_dataloader, mapping_test_dataloader,
        net, optimiser, heads)``.
    """
    dataloaders_head_A, mapping_assignment_dataloader, mapping_test_dataloader = \
        segmentation_create_dataloaders(config)
    # For segmentation both heads share the same dataloaders (unlike the
    # clustering scripts, which build separate ones per head).
    dataloaders_head_B = dataloaders_head_A

    net = archs.__dict__[config.arch](config)

    # Load the checkpoint once and keep it around; it holds both the net and
    # the optimiser state. (Previously stored in a variable named `dict`,
    # which shadowed the builtin.)
    saved_state = None
    if config.restart and dict_name is not None:
        saved_state = torch.load(os.path.join(config.out_dir, dict_name),
                                 map_location=lambda storage, loc: storage)
        net.load_state_dict(saved_state["net"])

    if not config.nocuda:
        net.cuda()
    net = torch.nn.DataParallel(net)
    net.train()

    optimiser = get_opt(config.opt)(net.module.parameters(), lr=config.lr)
    # Bug fix: the optimiser state was previously restored whenever
    # config.restart was set, even when no checkpoint had been loaded
    # (dict_name None), which crashed on the unloaded state dict. Restore it
    # only when a checkpoint was actually read.
    if saved_state is not None:
        optimiser.load_state_dict(saved_state["optimiser"])

    # Head A first by default; optional config flag flips the order.
    heads = ["A", "B"]
    if hasattr(config, "head_B_first") and config.head_B_first:
        heads = ["B", "A"]

    return (dataloaders_head_A, dataloaders_head_B,
            mapping_assignment_dataloader, mapping_test_dataloader,
            net, optimiser, heads)
def train():
    """Train the two-head IIC segmentation model end to end.

    Reads module-level globals (``config`` and, on restart, ``dict_name``):
    builds dataloaders / net / optimiser, optionally restores a checkpoint,
    then alternates optimisation of heads A and B each epoch, evaluating,
    plotting and checkpointing after every epoch.
    """
    dataloaders_head_A, mapping_assignment_dataloader, mapping_test_dataloader = \
        segmentation_create_dataloaders(config)
    dataloaders_head_B = dataloaders_head_A  # unlike for clustering datasets

    net = archs.__dict__[config.arch](config)
    if config.restart:
        # NOTE(review): `dict` shadows the builtin; checkpoint holds both net
        # and optimiser state.
        dict = torch.load(os.path.join(config.out_dir, dict_name),
                          map_location=lambda storage, loc: storage)
        net.load_state_dict(dict["net"])
    net.cuda()
    net = torch.nn.DataParallel(net)
    net.train()

    optimiser = get_opt(config.opt)(net.module.parameters(), lr=config.lr)
    if config.restart:
        optimiser.load_state_dict(dict["optimiser"])

    # Head order: A first unless config asks for B first.
    heads = ["A", "B"]
    if hasattr(config, "head_B_first") and config.head_B_first:
        heads = ["B", "A"]

    # Results
    # ----------------------------------------------------------------------
    if config.restart:
        next_epoch = config.last_epoch + 1
        print("starting from epoch %d" % next_epoch)

        # Truncate history lists back to the last saved epoch; the per-epoch
        # loss lists lag the accuracy lists by one entry (no loss recorded
        # for the pre-training eval), hence the (next_epoch - 1) cut.
        config.epoch_acc = config.epoch_acc[:next_epoch]  # in case we overshot
        config.epoch_avg_subhead_acc = config.epoch_avg_subhead_acc[:next_epoch]
        config.epoch_stats = config.epoch_stats[:next_epoch]

        config.epoch_loss_head_A = config.epoch_loss_head_A[:(next_epoch - 1)]
        config.epoch_loss_no_lamb_head_A = config.epoch_loss_no_lamb_head_A[:(next_epoch - 1)]
        config.epoch_loss_head_B = config.epoch_loss_head_B[:(next_epoch - 1)]
        config.epoch_loss_no_lamb_head_B = config.epoch_loss_no_lamb_head_B[:(next_epoch - 1)]
    else:
        # Fresh run: start empty histories and record a pre-training eval.
        config.epoch_acc = []
        config.epoch_avg_subhead_acc = []
        config.epoch_stats = []

        config.epoch_loss_head_A = []
        config.epoch_loss_no_lamb_head_A = []
        config.epoch_loss_head_B = []
        config.epoch_loss_no_lamb_head_B = []

        _ = segmentation_eval(config, net,
                              mapping_assignment_dataloader=mapping_assignment_dataloader,
                              mapping_test_dataloader=mapping_test_dataloader,
                              sobel=(not config.no_sobel),
                              using_IR=config.using_IR)

        print("Pre: time %s: \n %s" % (datetime.now(), nice(config.epoch_stats[-1])))
        sys.stdout.flush()
        next_epoch = 1

    fig, axarr = plt.subplots(6, sharex=False, figsize=(20, 20))

    # Choose between the condensed and uncollapsed segmentation losses.
    if not config.use_uncollapsed_loss:
        print("using condensed loss (default)")
        loss_fn = IID_segmentation_loss
    else:
        print("using uncollapsed loss!")
        loss_fn = IID_segmentation_loss_uncollapsed

    # Train
    # ------------------------------------------------------------------------
    for e_i in xrange(next_epoch, config.num_epochs):
        print("Starting e_i: %d %s" % (e_i, datetime.now()))
        sys.stdout.flush()

        if e_i in config.lr_schedule:
            optimiser = update_lr(optimiser, lr_mult=config.lr_mult)

        for head_i in range(2):
            head = heads[head_i]
            # Select the dataloaders, loss history lists and lambda for this head.
            if head == "A":
                dataloaders = dataloaders_head_A
                epoch_loss = config.epoch_loss_head_A
                epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_A
                lamb = config.lamb_A
            elif head == "B":
                dataloaders = dataloaders_head_B
                epoch_loss = config.epoch_loss_head_B
                epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_B
                lamb = config.lamb_B

            iterators = (d for d in dataloaders)
            b_i = 0

            avg_loss = 0.  # over heads and head_epochs (and subheads)
            avg_loss_no_lamb = 0.
            avg_loss_count = 0

            # One tuple per batch, drawn in lockstep from all dataloaders.
            for tup in itertools.izip(*iterators):
                net.module.zero_grad()

                # When sobel preprocessing is on, one input channel is
                # produced later by sobel_process, so pre-sobel tensors have
                # one channel fewer.
                if not config.no_sobel:
                    pre_channels = config.in_channels - 1
                else:
                    pre_channels = config.in_channels

                # Pre-allocate batch buffers on GPU; sliced down below if the
                # final batch is short.
                all_img1 = torch.zeros(config.batch_sz, pre_channels,
                                       config.input_sz, config.input_sz).to(torch.float32).cuda()
                all_img2 = torch.zeros(config.batch_sz, pre_channels,
                                       config.input_sz, config.input_sz).to(torch.float32).cuda()
                all_affine2_to_1 = torch.zeros(config.batch_sz, 2, 3).to(torch.float32).cuda()
                all_mask_img1 = torch.zeros(config.batch_sz, config.input_sz,
                                            config.input_sz).to(torch.float32).cuda()

                curr_batch_sz = tup[0][0].shape[0]
                for d_i in xrange(config.num_dataloaders):
                    img1, img2, affine2_to_1, mask_img1 = tup[d_i]
                    assert (img1.shape[0] == curr_batch_sz)

                    actual_batch_start = d_i * curr_batch_sz
                    actual_batch_end = actual_batch_start + curr_batch_sz

                    all_img1[actual_batch_start:actual_batch_end, :, :, :] = img1
                    all_img2[actual_batch_start:actual_batch_end, :, :, :] = img2
                    all_affine2_to_1[actual_batch_start:actual_batch_end, :, :] = affine2_to_1
                    all_mask_img1[actual_batch_start:actual_batch_end, :, :] = mask_img1

                if not (curr_batch_sz == config.dataloader_batch_sz) and (e_i == next_epoch):
                    print("last batch sz %d" % curr_batch_sz)

                curr_total_batch_sz = curr_batch_sz * config.num_dataloaders  # times 2
                all_img1 = all_img1[:curr_total_batch_sz, :, :, :]
                all_img2 = all_img2[:curr_total_batch_sz, :, :, :]
                all_affine2_to_1 = all_affine2_to_1[:curr_total_batch_sz, :, :]
                all_mask_img1 = all_mask_img1[:curr_total_batch_sz, :, :]

                if (not config.no_sobel):
                    all_img1 = sobel_process(all_img1, config.include_rgb,
                                             using_IR=config.using_IR)
                    all_img2 = sobel_process(all_img2, config.include_rgb,
                                             using_IR=config.using_IR)

                x1_outs = net(all_img1, head=head)
                x2_outs = net(all_img2, head=head)

                # Average the IIC loss over all subheads of this head.
                avg_loss_batch = None  # avg over the heads
                avg_loss_no_lamb_batch = None
                for i in xrange(config.num_subheads):
                    loss, loss_no_lamb = loss_fn(x1_outs[i], x2_outs[i],
                                                 all_affine2_to_1=all_affine2_to_1,
                                                 all_mask_img1=all_mask_img1,
                                                 lamb=lamb,
                                                 half_T_side_dense=config.half_T_side_dense,
                                                 half_T_side_sparse_min=config.half_T_side_sparse_min,
                                                 half_T_side_sparse_max=config.half_T_side_sparse_max)

                    if avg_loss_batch is None:
                        avg_loss_batch = loss
                        avg_loss_no_lamb_batch = loss_no_lamb
                    else:
                        avg_loss_batch += loss
                        avg_loss_no_lamb_batch += loss_no_lamb

                avg_loss_batch /= config.num_subheads
                avg_loss_no_lamb_batch /= config.num_subheads

                if ((b_i % 100) == 0) or (e_i == next_epoch):
                    print("Model ind %d epoch %d head %s batch: %d avg loss %f avg loss no "
                          "lamb %f "
                          "time %s" % \
                          (config.model_ind, e_i, head, b_i, avg_loss_batch.item(),
                           avg_loss_no_lamb_batch.item(), datetime.now()))
                    sys.stdout.flush()

                # Abort immediately on NaN/inf loss rather than training on.
                if not np.isfinite(avg_loss_batch.item()):
                    print("Loss is not finite... %s:" % str(avg_loss_batch))
                    exit(1)

                avg_loss += avg_loss_batch.item()
                avg_loss_no_lamb += avg_loss_no_lamb_batch.item()
                avg_loss_count += 1

                avg_loss_batch.backward()
                optimiser.step()

                torch.cuda.empty_cache()

                b_i += 1
                if b_i == 2 and config.test_code:
                    break

            avg_loss = float(avg_loss / avg_loss_count)
            avg_loss_no_lamb = float(avg_loss_no_lamb / avg_loss_count)

            epoch_loss.append(avg_loss)
            epoch_loss_no_lamb.append(avg_loss_no_lamb)

        # Eval
        # -----------------------------------------------------------------------
        is_best = segmentation_eval(config, net,
                                    mapping_assignment_dataloader=mapping_assignment_dataloader,
                                    mapping_test_dataloader=mapping_test_dataloader,
                                    sobel=(not config.no_sobel),
                                    using_IR=config.using_IR)

        print("Pre: time %s: \n %s" % (datetime.now(), nice(config.epoch_stats[-1])))
        sys.stdout.flush()

        # Refresh the six diagnostic plots and write them to out_dir.
        axarr[0].clear()
        axarr[0].plot(config.epoch_acc)
        axarr[0].set_title("acc (best), top: %f" % max(config.epoch_acc))

        axarr[1].clear()
        axarr[1].plot(config.epoch_avg_subhead_acc)
        axarr[1].set_title("acc (avg), top: %f" % max(config.epoch_avg_subhead_acc))

        axarr[2].clear()
        axarr[2].plot(config.epoch_loss_head_A)
        axarr[2].set_title("Loss head A")

        axarr[3].clear()
        axarr[3].plot(config.epoch_loss_no_lamb_head_A)
        axarr[3].set_title("Loss no lamb head A")

        axarr[4].clear()
        axarr[4].plot(config.epoch_loss_head_B)
        axarr[4].set_title("Loss head B")

        axarr[5].clear()
        axarr[5].plot(config.epoch_loss_no_lamb_head_B)
        axarr[5].set_title("Loss no lamb head B")

        fig.canvas.draw_idle()
        fig.savefig(os.path.join(config.out_dir, "plots.png"))

        # Checkpoint: move params to CPU for saving, then back to GPU.
        if is_best or (e_i % config.save_freq == 0):
            net.module.cpu()
            save_dict = {
                "net": net.module.state_dict(),
                "optimiser": optimiser.state_dict()
            }

            if e_i % config.save_freq == 0:
                torch.save(save_dict, os.path.join(config.out_dir, "latest.pytorch"))
                config.last_epoch = e_i  # for last saved version

            if is_best:
                torch.save(save_dict, os.path.join(config.out_dir, "best.pytorch"))

                with open(os.path.join(config.out_dir, "best_config.pickle"), 'wb') as outfile:
                    pickle.dump(config, outfile)
                with open(os.path.join(config.out_dir, "best_config.txt"), "w") as text_file:
                    text_file.write("%s" % config)

            net.module.cuda()

        # Always persist the (mutated) config alongside the run.
        with open(os.path.join(config.out_dir, "config.pickle"), 'wb') as outfile:
            pickle.dump(config, outfile)
        with open(os.path.join(config.out_dir, "config.txt"), "w") as text_file:
            text_file.write("%s" % config)

        if config.test_code:
            exit(0)
# Model ------------------------------------------------------------------------ dataloaders, mapping_assignment_dataloader, mapping_test_dataloader = \ cluster_create_dataloaders(config) net = archs.__dict__[config.arch](config) if config.restart: model_path = os.path.join(config.out_dir, net_name) net.load_state_dict( torch.load(model_path, map_location=lambda storage, loc: storage)) net.cuda() net = torch.nn.DataParallel(net) net.train() optimiser = get_opt(config.opt)(net.module.parameters(), lr=config.lr) if config.restart: optimiser.load_state_dict( torch.load(os.path.join(config.out_dir, opt_name))) # Results ---------------------------------------------------------------------- if config.restart: if not config.restart_from_best: next_epoch = config.last_epoch + 1 # corresponds to last saved model else: next_epoch = np.argmax(np.array(config.epoch_acc)) + 1 print("starting from epoch %d" % next_epoch) config.epoch_acc = config.epoch_acc[:next_epoch] # in case we overshot config.epoch_avg_subhead_acc = config.epoch_avg_subhead_acc[:next_epoch]
def train(render_count=-1):
    """Train the two-head IIC clustering model end to end.

    Reads module-level globals (``config``, ``net_name``, ``opt_name``,
    ``given_config``): builds the two-head dataloaders, net and optimiser,
    optionally restores a checkpoint, then alternates optimisation of heads
    A and B each epoch (each head possibly for several sub-epochs),
    evaluating, plotting and checkpointing after every epoch.

    Args:
        render_count: forwarded to ``save_progress`` when progression
            snapshots are enabled; -1 presumably means "render all" —
            TODO confirm against save_progress.
    """
    dataloaders_head_A, dataloaders_head_B, \
    mapping_assignment_dataloader, mapping_test_dataloader = \
        cluster_twohead_create_dataloaders(config)

    net = archs.__dict__[config.arch](config)
    if config.restart:
        model_path = os.path.join(config.out_dir, net_name)
        print("Model path: %s" % model_path)
        net.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))
    net.cuda()
    net = torch.nn.DataParallel(net)
    net.train()

    optimiser = get_opt(config.opt)(net.module.parameters(), lr=config.lr)
    if config.restart:
        # Skip reloading the optimiser for eval-only restarts (num_epochs == 0).
        if not (given_config is not None and given_config.num_epochs == 0):
            print("loading latest opt")
            optimiser.load_state_dict(
                torch.load(os.path.join(config.out_dir, opt_name)))

    # Head B (auxiliary over-clustering head) is trained first by default.
    heads = ["B", "A"]
    if config.head_A_first:
        heads = ["A", "B"]

    # Number of sub-epochs each head is trained for within one outer epoch.
    head_epochs = {}
    head_epochs["A"] = config.head_A_epochs
    head_epochs["B"] = config.head_B_epochs

    # Results
    # ----------------------------------------------------------------------
    if config.restart:
        # Resume either after the last saved epoch or just after the best one.
        if not config.restart_from_best:
            next_epoch = config.last_epoch + 1  # corresponds to last saved model
        else:
            next_epoch = np.argmax(np.array(config.epoch_acc)) + 1
        print("starting from epoch %d" % next_epoch)

        # in case we overshot without saving
        config.epoch_acc = config.epoch_acc[:next_epoch]  # in case we overshot
        config.epoch_avg_subhead_acc = config.epoch_avg_subhead_acc[:next_epoch]
        config.epoch_stats = config.epoch_stats[:next_epoch]

        if config.double_eval:
            config.double_eval_acc = config.double_eval_acc[:next_epoch]
            config.double_eval_avg_subhead_acc = config.double_eval_avg_subhead_acc[:next_epoch]
            config.double_eval_stats = config.double_eval_stats[:next_epoch]

        # Loss lists lag the accuracy lists by one (no loss for the
        # pre-training eval), hence the (next_epoch - 1) truncation.
        config.epoch_loss_head_A = config.epoch_loss_head_A[:(next_epoch - 1)]
        config.epoch_loss_no_lamb_head_A = config.epoch_loss_no_lamb_head_A[:(next_epoch - 1)]
        config.epoch_loss_head_B = config.epoch_loss_head_B[:(next_epoch - 1)]
        config.epoch_loss_no_lamb_head_B = config.epoch_loss_no_lamb_head_B[:(next_epoch - 1)]
    else:
        # Fresh run: empty histories plus one pre-training evaluation.
        config.epoch_acc = []
        config.epoch_avg_subhead_acc = []
        config.epoch_stats = []

        if config.double_eval:
            config.double_eval_acc = []
            config.double_eval_avg_subhead_acc = []
            config.double_eval_stats = []

        config.epoch_loss_head_A = []
        config.epoch_loss_no_lamb_head_A = []
        config.epoch_loss_head_B = []
        config.epoch_loss_no_lamb_head_B = []

        subhead = None
        if config.select_subhead_on_loss:
            subhead = get_subhead_using_loss(config, dataloaders_head_B, net,
                                             sobel=False, lamb=config.lamb_B)

        _ = cluster_eval(config, net,
                         mapping_assignment_dataloader=mapping_assignment_dataloader,
                         mapping_test_dataloader=mapping_test_dataloader,
                         sobel=False,
                         use_subhead=subhead)

        print("Pre: time %s: \n %s" % (datetime.now(), nice(config.epoch_stats[-1])))
        if config.double_eval:
            print("double eval: \n %s" % (nice(config.double_eval_stats[-1])))
        sys.stdout.flush()
        next_epoch = 1

    # Two extra plot rows when double_eval is enabled.
    fig, axarr = plt.subplots(6 + 2 * int(config.double_eval), sharex=False,
                              figsize=(20, 20))

    save_progression = hasattr(config, "save_progression") and \
                       config.save_progression
    if save_progression:
        save_progression_count = 0
        save_progress(config, net, mapping_assignment_dataloader,
                      mapping_test_dataloader, save_progression_count,
                      sobel=False, render_count=render_count)
        save_progression_count += 1

    # Train
    # ------------------------------------------------------------------------
    for e_i in xrange(next_epoch, config.num_epochs):
        print("Starting e_i: %d" % e_i)

        if e_i in config.lr_schedule:
            optimiser = update_lr(optimiser, lr_mult=config.lr_mult)

        for head_i in range(2):
            head = heads[head_i]
            # Select the dataloaders, loss histories and lambda for this head.
            if head == "A":
                dataloaders = dataloaders_head_A
                epoch_loss = config.epoch_loss_head_A
                epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_A
                lamb = config.lamb_A
            elif head == "B":
                dataloaders = dataloaders_head_B
                epoch_loss = config.epoch_loss_head_B
                epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_B
                lamb = config.lamb_B

            avg_loss = 0.  # over heads and head_epochs (and subheads)
            avg_loss_no_lamb = 0.
            avg_loss_count = 0

            for head_i_epoch in range(head_epochs[head]):
                sys.stdout.flush()

                iterators = (d for d in dataloaders)

                b_i = 0
                # One tuple per batch, drawn in lockstep from all dataloaders:
                # tup[0] is the original images, tup[1:] the transformed copies.
                for tup in itertools.izip(*iterators):
                    net.module.zero_grad()

                    # Pre-allocate GPU buffers; sliced down for a short final batch.
                    all_imgs = torch.zeros(
                        (config.batch_sz, config.in_channels,
                         config.input_sz, config.input_sz)).cuda()
                    all_imgs_tf = torch.zeros(
                        (config.batch_sz, config.in_channels,
                         config.input_sz, config.input_sz)).cuda()

                    imgs_curr = tup[0][0]  # always the first
                    curr_batch_sz = imgs_curr.size(0)
                    for d_i in xrange(config.num_dataloaders):
                        imgs_tf_curr = tup[1 + d_i][0]  # from 2nd to last
                        assert (curr_batch_sz == imgs_tf_curr.size(0))

                        actual_batch_start = d_i * curr_batch_sz
                        actual_batch_end = actual_batch_start + curr_batch_sz
                        # The same original images are paired with each
                        # transformed copy.
                        all_imgs[actual_batch_start:actual_batch_end, :, :, :] = \
                            imgs_curr.cuda()
                        all_imgs_tf[actual_batch_start:actual_batch_end, :, :, :] = \
                            imgs_tf_curr.cuda()

                    if not (curr_batch_sz == config.dataloader_batch_sz):
                        print("last batch sz %d" % curr_batch_sz)

                    curr_total_batch_sz = curr_batch_sz * config.num_dataloaders  # # times 2
                    all_imgs = all_imgs[:curr_total_batch_sz, :, :, :]
                    all_imgs_tf = all_imgs_tf[:curr_total_batch_sz, :, :, :]

                    x_outs = net(all_imgs)
                    x_tf_outs = net(all_imgs_tf)

                    # Average the IIC loss over all subheads of this head.
                    avg_loss_batch = None  # avg over the heads
                    avg_loss_no_lamb_batch = None
                    for i in xrange(config.num_subheads):
                        loss, loss_no_lamb = IID_loss(x_outs[i], x_tf_outs[i],
                                                      lamb=lamb)
                        if avg_loss_batch is None:
                            avg_loss_batch = loss
                            avg_loss_no_lamb_batch = loss_no_lamb
                        else:
                            avg_loss_batch += loss
                            avg_loss_no_lamb_batch += loss_no_lamb

                    avg_loss_batch /= config.num_subheads
                    avg_loss_no_lamb_batch /= config.num_subheads

                    if ((b_i % 100) == 0) or (e_i == next_epoch):
                        print(
                            "Model ind %d epoch %d head %s batch: %d avg loss %f avg loss no lamb %f time %s" % \
                            (config.model_ind, e_i, head, b_i, avg_loss_batch.item(),
                             avg_loss_no_lamb_batch.item(), datetime.now()))
                        sys.stdout.flush()

                    # Abort immediately on NaN/inf loss rather than training on.
                    if not np.isfinite(avg_loss_batch.item()):
                        print("Loss is not finite... %s:" % avg_loss_batch.item())
                        exit(1)

                    avg_loss += avg_loss_batch.item()
                    avg_loss_no_lamb += avg_loss_no_lamb_batch.item()
                    avg_loss_count += 1

                    avg_loss_batch.backward()
                    optimiser.step()

                    if ((b_i % 50) == 0) and save_progression:
                        save_progress(config, net, mapping_assignment_dataloader,
                                      mapping_test_dataloader,
                                      save_progression_count, sobel=False,
                                      render_count=render_count)
                        save_progression_count += 1

                    b_i += 1
                    if b_i == 2 and config.test_code:
                        break

            avg_loss = float(avg_loss / avg_loss_count)
            avg_loss_no_lamb = float(avg_loss_no_lamb / avg_loss_count)

            epoch_loss.append(avg_loss)
            epoch_loss_no_lamb.append(avg_loss_no_lamb)

        # Eval
        # -----------------------------------------------------------------------
        subhead = None
        if config.select_subhead_on_loss:
            subhead = get_subhead_using_loss(config, dataloaders_head_B, net,
                                             sobel=False, lamb=config.lamb_B)

        is_best = cluster_eval(config, net,
                               mapping_assignment_dataloader=mapping_assignment_dataloader,
                               mapping_test_dataloader=mapping_test_dataloader,
                               sobel=False,
                               use_subhead=subhead)

        print("Pre: time %s: \n %s" % (datetime.now(), nice(config.epoch_stats[-1])))
        if config.double_eval:
            print("double eval: \n %s" % (nice(config.double_eval_stats[-1])))
        sys.stdout.flush()

        # Refresh the diagnostic plots and write them to out_dir.
        axarr[0].clear()
        axarr[0].plot(config.epoch_acc)
        axarr[0].set_title("acc (best), top: %f" % max(config.epoch_acc))

        axarr[1].clear()
        axarr[1].plot(config.epoch_avg_subhead_acc)
        axarr[1].set_title("acc (avg), top: %f" % max(config.epoch_avg_subhead_acc))

        axarr[2].clear()
        axarr[2].plot(config.epoch_loss_head_A)
        axarr[2].set_title("Loss head A")

        axarr[3].clear()
        axarr[3].plot(config.epoch_loss_no_lamb_head_A)
        axarr[3].set_title("Loss no lamb head A")

        axarr[4].clear()
        axarr[4].plot(config.epoch_loss_head_B)
        axarr[4].set_title("Loss head B")

        axarr[5].clear()
        axarr[5].plot(config.epoch_loss_no_lamb_head_B)
        axarr[5].set_title("Loss no lamb head B")

        if config.double_eval:
            axarr[6].clear()
            axarr[6].plot(config.double_eval_acc)
            axarr[6].set_title("double eval acc (best), top: %f" %
                               max(config.double_eval_acc))

            axarr[7].clear()
            axarr[7].plot(config.double_eval_avg_subhead_acc)
            axarr[7].set_title("double eval acc (avg)), top: %f" %
                               max(config.double_eval_avg_subhead_acc))

        fig.tight_layout()
        fig.canvas.draw_idle()
        fig.savefig(os.path.join(config.out_dir, "plots.png"))

        # Checkpoint: move params to CPU for saving, then back to GPU.
        if is_best or (e_i % config.save_freq == 0):
            net.module.cpu()

            if e_i % config.save_freq == 0:
                torch.save(net.module.state_dict(),
                           os.path.join(config.out_dir, "latest_net.pytorch"))
                torch.save(
                    optimiser.state_dict(),
                    os.path.join(config.out_dir, "latest_optimiser.pytorch"))

                config.last_epoch = e_i  # for last saved version

            if is_best:
                # also serves as backup if hardware fails - less likely to hit this
                torch.save(net.module.state_dict(),
                           os.path.join(config.out_dir, "best_net.pytorch"))
                torch.save(
                    optimiser.state_dict(),
                    os.path.join(config.out_dir, "best_optimiser.pytorch"))

                with open(os.path.join(config.out_dir, "best_config.pickle"),
                          'wb') as outfile:
                    pickle.dump(config, outfile)
                with open(os.path.join(config.out_dir, "best_config.txt"),
                          "w") as text_file:
                    text_file.write("%s" % config)

            net.module.cuda()

        # Always persist the (mutated) config alongside the run.
        with open(os.path.join(config.out_dir, "config.pickle"), 'wb') as outfile:
            pickle.dump(config, outfile)
        with open(os.path.join(config.out_dir, "config.txt"), "w") as text_file:
            text_file.write("%s" % config)

        if config.test_code:
            exit(0)
def setup(config):
    """Validate/derive config fields, optionally resume, and build net + optimiser.

    For ``IID`` mode requires a two-head architecture and that exactly one
    output k equals the ground-truth k; for ``IID+`` requires a single output
    k >= gt_k. On restart, reloads the pickled config of the previous run and
    selectively overrides it with fields from the config passed in.

    Args:
        config: freshly parsed configuration namespace.

    Returns:
        Tuple of ``(config, net, optimiser)`` — note ``config`` may be a
        *different* object (the reloaded one) when restarting.

    Raises:
        NotImplementedError: for an unrecognised ``config.mode``.
    """
    if config.mode == "IID":
        assert ("TwoHead" in config.arch)
        # Exactly one config has to match the groundtruth k and all ks have to be bigger than gt_k
        assert any(k == config.gt_k for k in config.output_ks)
        assert all(k >= config.gt_k for k in config.output_ks)
        config.output_k = config.gt_k
        config.eval_mode = "hung"
        config.twohead = True
    elif config.mode == "IID+":
        assert len(config.output_ks) == 1 and config.output_ks[0] >= config.gt_k
        config.output_k = config.output_ks[0]
        config.eval_mode = "orig"
        config.twohead = False
        config.double_eval = False
    else:
        raise NotImplementedError

    # Sobel preprocessing adds gradient channels; RGB optionally kept alongside.
    if config.sobel:
        if not config.include_rgb:
            config.in_channels = 2
        else:
            config.in_channels = 5
    else:
        config.in_channels = 1

    config.train_partitions = [True, False]
    config.mapping_assignment_partitions = [True, False]
    config.mapping_test_partitions = [True, False]

    assert (config.batch_sz % config.num_dataloaders == 0)
    # Floor division: identical to `/` under Python 2 ints, but avoids a
    # float batch size if this code is ever run under Python 3.
    config.dataloader_batch_sz = config.batch_sz // config.num_dataloaders

    config.out_dir = os.path.join(config.out_root, str(config.model_ind))
    if not os.path.exists(config.out_dir):
        os.makedirs(config.out_dir)

    if config.restart:
        config_name = "config.pickle"
        net_name = "latest_net.pytorch"
        opt_name = "latest_optimiser.pytorch"

        if config.restart_from_best:
            config_name = "best_config.pickle"
            net_name = "best_net.pytorch"
            opt_name = "best_optimiser.pytorch"

        # Keep the CLI config around; the pickled config of the old run
        # becomes the working config, with selected fields copied over.
        given_config = config
        reloaded_config_path = os.path.join(given_config.out_dir, config_name)
        print("Loading restarting config from: %s" % reloaded_config_path)
        with open(reloaded_config_path, "rb") as config_f:
            config = pickle.load(config_f)

        # Backwards compatibility with the old attribute name.
        if hasattr(config, "num_sub_heads"):
            config.num_subheads = config.num_sub_heads

        assert given_config.test_code or (config.model_ind == given_config.model_ind)
        config.restart = True
        config.restart_from_best = given_config.restart_from_best

        # Copy over new paths in case the filesystem layout changed.
        if given_config.dataset_root is not None:
            config.dataset_root = given_config.dataset_root
        if given_config.out_root is not None:
            config.out_root = given_config.out_root
            config.out_dir = os.path.join(given_config.out_root,
                                          str(given_config.model_ind))

        # copy over new num_epochs and lr schedule
        config.num_epochs = given_config.num_epochs
        config.lr_schedule = given_config.lr_schedule
        config.save_progression = given_config.save_progression
        config.result_dir = given_config.result_dir

        # Defaults for fields added after older runs were saved.
        if not hasattr(config, "cutout"):
            config.cutout = False
            config.cutout_p = 0.5
            config.cutout_max_box = 0.5

        if not hasattr(config, "batchnorm_track"):
            config.batchnorm_track = True  # before we added in false option

        config.plot_cluster_stats = given_config.plot_cluster_stats

        if hasattr(given_config, "batch_sz") and hasattr(given_config,
                                                         "num_dataloaders"):
            assert (given_config.batch_sz % given_config.num_dataloaders == 0)
            config.batch_sz = given_config.batch_sz
            config.num_dataloaders = given_config.num_dataloaders
            # Same floor-division rationale as above.
            config.dataloader_batch_sz = config.batch_sz // config.num_dataloaders
    else:
        print("Config: %s" % config_to_str(config))
        net_name = None
        opt_name = None

    net = archs.__dict__[config.arch](config)
    if config.restart:
        model_path = os.path.join(config.out_dir, net_name)
        print("Model path: %s" % model_path)
        net.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))
    net.cuda()
    net = torch.nn.DataParallel(net)
    net.train()

    optimiser = get_opt(config.opt)(net.module.parameters(), lr=config.lr)
    if config.restart:
        opt_path = os.path.join(config.out_dir, opt_name)
        optimiser.load_state_dict(torch.load(opt_path))

    return config, net, optimiser