def load_checkpoint(self, filepath, load_iternum=True, ignore_failure=True, load_optim=True):
    if os.path.isfile(filepath):
        with open(filepath, 'rb') as f:
            checkpoint = torch.load(f)

        if load_iternum:
            self.iter = checkpoint['iter']

        for key, value in self.net_dict.items():
            try:
                if isinstance(value, dict):
                    state_dicts = checkpoint['model_states'][key]
                    for sub_key, net in value.items():
                        value[sub_key].load_state_dict(state_dicts[sub_key], strict=False)
                else:
                    value.load_state_dict(checkpoint['model_states'][key], strict=False)
            except Exception as e:
                logging.warning("Could not load {}".format(key))
                logging.warning(str(e))
                if not ignore_failure:
                    raise e

        if load_optim:
            for key, value in self.optim_dict.items():
                try:
                    if isinstance(value, dict):
                        state_dicts = checkpoint['optim_states'][key]
                        for sub_key, net in value.items():
                            value[sub_key].load_state_dict(state_dicts[sub_key])
                    else:
                        value.load_state_dict(checkpoint['optim_states'][key])
                except Exception as e:
                    logging.warning("Could not load {}".format(key))
                    logging.warning(str(e))
                    if not ignore_failure:
                        raise e

        self.pbar.update(self.iter)
        # TODO: remove the following hard-coded lr assumption on optim_G
        # Assuming optim_G is the generator's optimizer and the one we are interested in
        logging.info("Model Loaded: {} @ iter:{}, lr:{:.6f}".format(
            filepath, self.iter, get_lr(self.optim_dict['optim_G'])))
    else:
        logging.error("File does not exist: {}".format(filepath))

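# NOTE: `get_lr` used above is not defined in this snippet. A minimal sketch, assuming
# it simply reads the learning rate of the first parameter group of a torch optimizer
# (the helper in the original repository may differ):
def get_lr(optimizer):
    """Return the current learning rate of a torch.optim optimizer."""
    for param_group in optimizer.param_groups:
        return param_group['lr']
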
def train(net, opt, train_dataloader, val_dataloader, context, run_id):
    """Training function."""
    if not opt.skip_pretrain_validation:
        validation_results = validate(net, val_dataloader, context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    steps = parse_steps(opt.steps, opt.epochs, logging)

    opt_options = {'learning_rate': opt.lr, 'wd': opt.wd, 'clip_gradient': 10.}
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    if opt.optimizer == 'adam':
        opt_options['epsilon'] = 1e-7

    trainer = mx.gluon.Trainer(net.collect_params(), opt.optimizer, opt_options,
                               kvstore=opt.kvstore)

    L = DiscriminativeLoss(train_dataloader._dataset.num_classes(),
                           len(train_dataloader._dataset))
    L.initialize(ctx=context)
    if not opt.disable_hybridize:
        L.hybridize()

    smoothing_constant = .01  # for tracking moving losses
    moving_loss = 0
    best_results = []  # R@1, NMI

    for epoch in range(1, opt.epochs + 1):
        p_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader),
                     desc=('[Run %d/%d] Epoch %d' % (run_id, opt.number_of_runs, epoch)))

        trainer.set_learning_rate(get_lr(opt.lr, epoch, steps, opt.factor))

        for i, batch in p_bar:
            data = mx.gluon.utils.split_and_load(batch[0], ctx_list=context, batch_axis=0, even_split=False)
            label = mx.gluon.utils.split_and_load(batch[1], ctx_list=context, batch_axis=0, even_split=False)
            negative_labels = mx.gluon.utils.split_and_load(batch[2], ctx_list=context, batch_axis=0, even_split=False)

            with ag.record():
                losses = []
                for x, y, nl in zip(data, label, negative_labels):
                    embs = net(x)
                    losses.append(L(embs, y, nl))
            for l in losses:
                l.backward()

            trainer.step(len(losses))

            # Keep a moving average of the losses
            curr_loss = mx.nd.mean(mx.nd.concatenate(losses)).asscalar()
            moving_loss = (curr_loss if ((i == 0) and (epoch == 1))  # starting value
                           else (1 - smoothing_constant) * moving_loss
                           + smoothing_constant * curr_loss)  # add current
            p_bar.set_postfix_str('Moving loss: %.4f' % moving_loss)

        logging.info('Moving loss: %.4f' % moving_loss)

        validation_results = validate(net, val_dataloader, context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' % (epoch, name, val_acc))

        if (len(best_results) == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            if opt.save_model_prefix.lower() != 'none':
                filename = '%s.params' % opt.save_model_prefix
                logging.info('Saving %s.' % filename)
                net.save_parameters(filename)
            logging.info('New best validation: R@1: %f NMI: %f'
                         % (best_results[0][1], best_results[-1][1]))

    return best_results

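# NOTE: `parse_steps` and the schedule variant of `get_lr(base_lr, epoch, steps, factor)`
# used by the Gluon training loops are not defined in this snippet. A minimal sketch of a
# compatible step-decay schedule; the parsing rules and decay behaviour here are
# assumptions, not the original implementation:
def parse_steps(steps, epochs, logger):
    """Parse a comma-separated list of decay epochs, e.g. '30,60,90'."""
    parsed = [int(s) for s in str(steps).split(',') if s.strip() != '']
    logger.info('LR decay steps: %s (out of %d epochs)' % (parsed, epochs))
    return parsed


def get_lr(base_lr, epoch, steps, factor):
    """Multiply base_lr by `factor` once for every decay step that has been reached."""
    lr = base_lr
    for step in steps:
        if epoch >= step:
            lr *= factor
    return lr
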
def train(net, opt, train_dataloader, val_dataloader, context, run_id):
    """Training function."""
    if not opt.skip_pretrain_validation:
        validation_results = validate(net, val_dataloader, context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    # Calculate decay steps
    steps = parse_steps(opt.steps, opt.epochs, logger=logging)

    # Init optimizer
    opt_options = {'learning_rate': opt.lr, 'wd': opt.wd, 'clip_gradient': 10.}
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    elif opt.optimizer in ['adam', 'radam']:
        opt_options['epsilon'] = opt.epsilon
    elif opt.optimizer == 'rmsprop':
        opt_options['gamma1'] = 0.9
        opt_options['epsilon'] = opt.epsilon

    trainer = mx.gluon.Trainer(net.collect_params(), opt.optimizer, opt_options,
                               kvstore=opt.kvstore)

    # Init loss function
    if opt.loss == 'nca':
        logging.info('Using NCA loss')
        proxyloss = ProxyNCALoss(train_dataloader._dataset.num_classes(),
                                 exclude_positives=True, label_smooth=opt.label_smooth,
                                 multiplier=opt.embedding_multiplier,
                                 temperature=opt.temperature)
    elif opt.loss == 'triplet':
        logging.info('Using triplet loss')
        proxyloss = ProxyTripletLoss(train_dataloader._dataset.num_classes())
    elif opt.loss == 'xentropy':
        logging.info('Using NCA loss without excluding positives')
        proxyloss = ProxyNCALoss(train_dataloader._dataset.num_classes(),
                                 exclude_positives=False, label_smooth=opt.label_smooth,
                                 multiplier=opt.embedding_multiplier,
                                 temperature=opt.temperature)
    else:
        raise RuntimeError('Unknown loss function: %s' % opt.loss)

    smoothing_constant = .01  # for tracking moving losses
    moving_loss = 0
    best_results = []  # R@1, NMI

    for epoch in range(1, opt.epochs + 1):
        p_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader),
                     desc=('[Run %d/%d] Epoch %d' % (run_id, opt.number_of_runs, epoch)))

        new_lr = get_lr(opt.lr, epoch, steps, opt.factor)
        logging.info('Setting LR to %f' % new_lr)
        trainer.set_learning_rate(new_lr)

        if opt.optimizer == 'rmsprop':
            # exponential decay of gamma
            if epoch != 1:
                trainer._optimizer.gamma1 *= .94
                logging.info('Setting rmsprop gamma to %f' % trainer._optimizer.gamma1)

        for (i, batch) in p_bar:
            if opt.iteration_per_epoch > 0:
                for b in range(len(batch)):
                    batch[b] = batch[b][0]

            data = mx.gluon.utils.split_and_load(batch[0], ctx_list=context, batch_axis=0, even_split=False)
            label = mx.gluon.utils.split_and_load(batch[1], ctx_list=context, batch_axis=0, even_split=False)
            negative_labels = mx.gluon.utils.split_and_load(batch[2], ctx_list=context, batch_axis=0, even_split=False)

            with ag.record():
                losses = []
                for x, y, nl in zip(data, label, negative_labels):
                    embs, positive_proxy, negative_proxies, proxies = net(x, y, nl)
                    if opt.loss in ['nca', 'xentropy']:
                        losses.append(proxyloss(embs, proxies, y, nl))
                    else:
                        losses.append(proxyloss(embs, positive_proxy, negative_proxies))
            for l in losses:
                l.backward()

            trainer.step(data[0].shape[0])

            # Keep a moving average of the losses
            curr_loss = mx.nd.mean(mx.nd.maximum(mx.nd.concatenate(losses), 0)).asscalar()
            moving_loss = (curr_loss if ((i == 0) and (epoch == 1))  # starting value
                           else (1 - smoothing_constant) * moving_loss
                           + smoothing_constant * curr_loss)
            p_bar.set_postfix_str('Moving loss: %.4f' % moving_loss)

        logging.info('Moving loss: %.4f' % moving_loss)

        validation_results = validate(net, val_dataloader, context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' % (epoch, name, val_acc))

        if (len(best_results) == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            if opt.save_model_prefix.lower() != 'none':
                filename = '%s.params' % opt.save_model_prefix
                logging.info('Saving %s.' % filename)
                net.save_parameters(filename)
            logging.info('New best validation: R@1: %f NMI: %f'
                         % (best_results[0][1], best_results[-1][1]))

    return best_results

def train(net, opt, train_dataloader, val_dataloader, context, run_id):
    """Training function."""
    if not opt.skip_pretrain_validation:
        validation_results = validate(net, val_dataloader, context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    steps = parse_steps(opt.steps, opt.epochs, logging)

    opt_options = {'learning_rate': opt.lr, 'wd': opt.wd}
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    if opt.optimizer == 'adam':
        opt_options['epsilon'] = 1e-7

    trainer = mx.gluon.Trainer(net.collect_params(), opt.optimizer, opt_options,
                               kvstore=opt.kvstore)

    L = PrototypeLoss(opt.nc, opt.ns, opt.nq)
    data_size = opt.nc * (opt.ns + opt.nq)

    best_results = []  # R@1, NMI

    for epoch in range(1, opt.epochs + 1):
        prev_loss, cumulative_loss = 0.0, 0.0

        trainer.set_learning_rate(get_lr(opt.lr, epoch, steps, opt.factor))
        logging.info('Epoch %d learning rate=%f', epoch, trainer.learning_rate)

        p_bar = tqdm(train_dataloader,
                     desc=('[Run %d/%d] Epoch %d' % (run_id, opt.number_of_runs, epoch)))
        for batch in p_bar:
            supports_batch, queries_batch, labels_batch = [x[0] for x in batch]
            # supports_batch: <Nc x Ns x I>
            # queries_batch:  <Nc x Nq x I>
            # labels_batch:   <Nc x 1>

            supports_batch = mx.nd.reshape(supports_batch, (-1, 0, 0, 0), reverse=True)  # <(Nc * Ns) x I>
            queries_batch = mx.nd.reshape(queries_batch, (-1, 0, 0, 0), reverse=True)

            queries = mx.gluon.utils.split_and_load(queries_batch, ctx_list=context, batch_axis=0)
            supports = mx.gluon.utils.split_and_load(supports_batch, ctx_list=context, batch_axis=0)

            support_embs = []
            queries_embs = []
            with ag.record():
                for s in supports:
                    s_emb = net(s)
                    support_embs.append(s_emb)
                supports = mx.nd.concat(*support_embs, dim=0)  # <Nc*Ns x E>

                for q in queries:
                    q_emb = net(q)
                    queries_embs.append(q_emb)
                queries = mx.nd.concat(*queries_embs, dim=0)  # <Nc*Nq x E>

                loss = L(supports, queries)
            loss.backward()
            cumulative_loss += mx.nd.mean(loss).asscalar()

            trainer.step(data_size)

            p_bar.set_postfix({'loss': cumulative_loss - prev_loss})
            prev_loss = cumulative_loss

        validation_results = validate(net, val_dataloader, context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' % (epoch, name, val_acc))

        if (len(best_results) == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            if opt.save_model_prefix.lower() != 'none':
                filename = '%s.params' % opt.save_model_prefix
                logging.info('Saving %s.' % filename)
                net.save_parameters(filename)
            logging.info('New best validation: R@1: %f NMI: %f'
                         % (best_results[0][1], best_results[-1][1]))

    return best_results

def train(cmd_args):
    if not os.path.exists(cmd_args.exp_path):
        os.makedirs(cmd_args.exp_path)

    with open(joinpath(cmd_args.exp_path, 'options.txt'), 'w') as f:
        param_dict = vars(cmd_args)
        for param in param_dict:
            f.write(param + ' = ' + str(param_dict[param]) + '\n')

    logpath = joinpath(cmd_args.exp_path, 'eval.result')
    param_cnt_path = joinpath(cmd_args.exp_path, 'param_count.txt')

    # dataset and KG
    dataset = Dataset(cmd_args.data_root, cmd_args.batchsize,
                      cmd_args.shuffle_sampling, load_method=cmd_args.load_method)
    kg = KnowledgeGraph(dataset.fact_dict, PRED_DICT, dataset)

    # model
    if cmd_args.use_gcn == 1:
        gcn = GCN(kg, cmd_args.embedding_size - cmd_args.gcn_free_size,
                  cmd_args.gcn_free_size, num_hops=cmd_args.num_hops,
                  num_layers=cmd_args.num_mlp_layers,
                  transductive=cmd_args.trans == 1).to(cmd_args.device)
    else:
        gcn = TrainableEmbedding(kg, cmd_args.embedding_size).to(cmd_args.device)

    posterior_model = FactorizedPosterior(kg, cmd_args.embedding_size,
                                          cmd_args.slice_dim).to(cmd_args.device)
    mln = ConditionalMLN(cmd_args, dataset.rule_ls)

    if cmd_args.model_load_path is not None:
        gcn.load_state_dict(torch.load(joinpath(cmd_args.model_load_path, 'gcn.model')))
        posterior_model.load_state_dict(
            torch.load(joinpath(cmd_args.model_load_path, 'posterior.model')))

    # optimizers
    monitor = EarlyStopMonitor(cmd_args.patience)
    all_params = chain.from_iterable([posterior_model.parameters(), gcn.parameters()])
    optimizer = optim.Adam(all_params, lr=cmd_args.learning_rate,
                           weight_decay=cmd_args.l2_coef)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'max', factor=cmd_args.lr_decay_factor,
        patience=cmd_args.lr_decay_patience, min_lr=cmd_args.lr_decay_min)

    with open(param_cnt_path, 'w') as f:
        cnt_gcn_params = count_parameters(gcn)
        cnt_posterior_params = count_parameters(posterior_model)
        if cmd_args.use_gcn == 1:
            f.write('GCN params count: %d\n' % cnt_gcn_params)
        elif cmd_args.use_gcn == 0:
            f.write('plain params count: %d\n' % cnt_gcn_params)
        f.write('posterior params count: %d\n' % cnt_posterior_params)
        f.write('Total params count: %d\n' % (cnt_gcn_params + cnt_posterior_params))

    if cmd_args.no_train == 1:
        cmd_args.num_epochs = 0

    # for Freebase data
    if cmd_args.load_method == 1:

        # prepare data for M-step
        tqdm.write('preparing data for M-step...')

        pred_arg1_set_arg2 = dict()
        pred_arg2_set_arg1 = dict()
        pred_fact_set = dict()
        for pred in dataset.fact_dict_2:
            pred_arg1_set_arg2[pred] = dict()
            pred_arg2_set_arg1[pred] = dict()
            pred_fact_set[pred] = set()
            for _, args in dataset.fact_dict_2[pred]:
                if args[0] not in pred_arg1_set_arg2[pred]:
                    pred_arg1_set_arg2[pred][args[0]] = set()
                if args[1] not in pred_arg2_set_arg1[pred]:
                    pred_arg2_set_arg1[pred][args[1]] = set()
                pred_arg1_set_arg2[pred][args[0]].add(args[1])
                pred_arg2_set_arg1[pred][args[1]].add(args[0])
                pred_fact_set[pred].add(args)

        grounded_rules = []
        for rule_idx, rule in enumerate(dataset.rule_ls):
            grounded_rules.append(set())
            body_atoms = []
            head_atom = None
            for atom in rule.atom_ls:
                if atom.neg:
                    body_atoms.append(atom)
                elif head_atom is None:
                    head_atom = atom
            # atom in body must be observed
            assert len(body_atoms) <= 2

            if len(body_atoms) > 0:
                body1 = body_atoms[0]
                for _, body1_args in dataset.fact_dict_2[body1.pred_name]:
                    var2arg = dict()
                    var2arg[body1.var_name_ls[0]] = body1_args[0]
                    var2arg[body1.var_name_ls[1]] = body1_args[1]
                    for body2 in body_atoms[1:]:
                        if body2.var_name_ls[0] in var2arg:
                            if var2arg[body2.var_name_ls[0]] in pred_arg1_set_arg2[body2.pred_name]:
                                for body2_arg2 in pred_arg1_set_arg2[body2.pred_name][var2arg[body2.var_name_ls[0]]]:
                                    var2arg[body2.var_name_ls[1]] = body2_arg2
                                    grounded_rules[rule_idx].add(tuple(sorted(var2arg.items())))
                        elif body2.var_name_ls[1] in var2arg:
                            if var2arg[body2.var_name_ls[1]] in pred_arg2_set_arg1[body2.pred_name]:
                                for body2_arg1 in pred_arg2_set_arg1[body2.pred_name][var2arg[body2.var_name_ls[1]]]:
                                    var2arg[body2.var_name_ls[0]] = body2_arg1
                                    grounded_rules[rule_idx].add(tuple(sorted(var2arg.items())))

        # Collect head atoms derived by grounded formulas
        grounded_obs = dict()
        grounded_hid = dict()
        grounded_hid_score = dict()
        cnt_hid = 0
        for rule_idx in range(len(dataset.rule_ls)):
            rule = dataset.rule_ls[rule_idx]
            for var2arg in grounded_rules[rule_idx]:
                var2arg = dict(var2arg)
                head_atom = rule.atom_ls[-1]
                assert not head_atom.neg  # head atom
                pred = head_atom.pred_name
                args = (var2arg[head_atom.var_name_ls[0]], var2arg[head_atom.var_name_ls[1]])
                if args in pred_fact_set[pred]:
                    if (pred, args) not in grounded_obs:
                        grounded_obs[(pred, args)] = []
                    grounded_obs[(pred, args)].append(rule_idx)
                else:
                    if (pred, args) not in grounded_hid:
                        grounded_hid[(pred, args)] = []
                    grounded_hid[(pred, args)].append(rule_idx)
        tqdm.write('observed: %d, hidden: %d' % (len(grounded_obs), len(grounded_hid)))

        # Aggregate atoms by predicates for fast inference
        pred_aggregated_hid = dict()
        pred_aggregated_hid_args = dict()
        for (pred, args) in grounded_hid:
            if pred not in pred_aggregated_hid:
                pred_aggregated_hid[pred] = []
            if pred not in pred_aggregated_hid_args:
                pred_aggregated_hid_args[pred] = []
            pred_aggregated_hid[pred].append(
                (dataset.const2ind[args[0]], dataset.const2ind[args[1]]))
            pred_aggregated_hid_args[pred].append(args)
        pred_aggregated_hid_list = [[pred, pred_aggregated_hid[pred]]
                                    for pred in sorted(pred_aggregated_hid.keys())]

        for current_epoch in range(cmd_args.num_epochs):

            # E-step: optimize the parameters in the posterior model
            num_batches = int(math.ceil(len(dataset.test_fact_ls) / cmd_args.batchsize))
            pbar = tqdm(total=num_batches)
            acc_loss = 0.0
            cur_batch = 0

            for samples_by_r, latent_mask_by_r, neg_mask_by_r, obs_var_by_r, neg_var_by_r in \
                    dataset.get_batch_by_q(cmd_args.batchsize):

                node_embeds = gcn(dataset)

                loss = 0.0
                r_cnt = 0
                for ind, samples in enumerate(samples_by_r):
                    neg_mask = neg_mask_by_r[ind]
                    latent_mask = latent_mask_by_r[ind]
                    obs_var = obs_var_by_r[ind]
                    neg_var = neg_var_by_r[ind]

                    if sum([len(e[1]) for e in neg_mask]) == 0:
                        continue

                    potential, posterior_prob, obs_xent = posterior_model(
                        [samples, neg_mask, latent_mask, obs_var, neg_var],
                        node_embeds, fast_mode=True)

                    if cmd_args.no_entropy == 1:
                        entropy = 0
                    else:
                        entropy = compute_entropy(posterior_prob) / cmd_args.entropy_temp

                    loss += -(potential.sum() * dataset.rule_ls[ind].weight + entropy) / \
                            (potential.size(0) + 1e-6) + obs_xent
                    r_cnt += 1

                if r_cnt > 0:
                    loss /= r_cnt
                    acc_loss += loss.item()

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                pbar.update()
                cur_batch += 1
                pbar.set_description('Epoch %d, train loss: %.4f, lr: %.4g'
                                     % (current_epoch, acc_loss / cur_batch, get_lr(optimizer)))

            # M-step: optimize the weights of logic rules
            with torch.no_grad():
                posterior_prob = posterior_model(pred_aggregated_hid_list, node_embeds,
                                                 fast_inference_mode=True)
                for pred_i, (pred, var_ls) in enumerate(pred_aggregated_hid_list):
                    for var_i, var in enumerate(var_ls):
                        args = pred_aggregated_hid_args[pred][var_i]
                        grounded_hid_score[(pred, args)] = posterior_prob[pred_i][var_i]

                rule_weight_gradient = torch.zeros(len(dataset.rule_ls))
                for (pred, args) in grounded_obs:
                    for rule_idx in set(grounded_obs[(pred, args)]):
                        rule_weight_gradient[rule_idx] += 1.0 - compute_MB_proba(
                            dataset.rule_ls, grounded_obs[(pred, args)])
                for (pred, args) in grounded_hid:
                    for rule_idx in set(grounded_hid[(pred, args)]):
                        target = grounded_hid_score[(pred, args)]
                        rule_weight_gradient[rule_idx] += target - compute_MB_proba(
                            dataset.rule_ls, grounded_hid[(pred, args)])

                for rule_idx, rule in enumerate(dataset.rule_ls):
                    rule.weight += cmd_args.learning_rate_rule_weights * rule_weight_gradient[rule_idx]
                    print(dataset.rule_ls[rule_idx].weight, end=' ')

            pbar.close()

            # validation
            with torch.no_grad():
                node_embeds = gcn(dataset)

                valid_loss = 0.0
                cnt_batch = 0
                for samples_by_r, latent_mask_by_r, neg_mask_by_r, obs_var_by_r, neg_var_by_r in \
                        dataset.get_batch_by_q(cmd_args.batchsize, validation=True):

                    loss = 0.0
                    r_cnt = 0
                    for ind, samples in enumerate(samples_by_r):
                        neg_mask = neg_mask_by_r[ind]
                        latent_mask = latent_mask_by_r[ind]
                        obs_var = obs_var_by_r[ind]
                        neg_var = neg_var_by_r[ind]

                        if sum([len(e[1]) for e in neg_mask]) == 0:
                            continue

                        valid_potential, valid_prob, valid_obs_xent = posterior_model(
                            [samples, neg_mask, latent_mask, obs_var, neg_var],
                            node_embeds, fast_mode=True)

                        if cmd_args.no_entropy == 1:
                            valid_entropy = 0
                        else:
                            valid_entropy = compute_entropy(valid_prob) / cmd_args.entropy_temp

                        loss += -(valid_potential.sum() + valid_entropy) / \
                                (valid_potential.size(0) + 1e-6) + valid_obs_xent
                        r_cnt += 1

                    if r_cnt > 0:
                        loss /= r_cnt
                        valid_loss += loss.item()

                    cnt_batch += 1

                tqdm.write('Epoch %d, valid loss: %.4f' % (current_epoch, valid_loss / cnt_batch))

            should_stop = monitor.update(valid_loss)
            scheduler.step(valid_loss)

            is_current_best = monitor.cnt == 0
            if is_current_best:
                savepath = joinpath(cmd_args.exp_path, 'saved_model')
                os.makedirs(savepath, exist_ok=True)
                torch.save(gcn.state_dict(), joinpath(savepath, 'gcn.model'))
                torch.save(posterior_model.state_dict(), joinpath(savepath, 'posterior.model'))

            should_stop = should_stop or (current_epoch + 1 == cmd_args.num_epochs)
            if should_stop:
                tqdm.write('Early stopping')
                break

        # ======================= generate rank list =======================
        node_embeds = gcn(dataset)

        pbar = tqdm(total=len(dataset.test_fact_ls))
        pbar.write('*' * 10 + ' Evaluation ' + '*' * 10)

        rrank = 0.0
        hits = 0.0
        cnt = 0

        rrank_pred = dict([(pred_name, 0.0) for pred_name in PRED_DICT])
        hits_pred = dict([(pred_name, 0.0) for pred_name in PRED_DICT])
        cnt_pred = dict([(pred_name, 0.0) for pred_name in PRED_DICT])

        for pred_name, X, invX, sample in gen_eval_query(dataset, const2ind=kg.ent2idx):
            x_mat = np.array(X)
            invx_mat = np.array(invX)
            sample_mat = np.array(sample)

            tail_score, head_score, true_score = posterior_model(
                [pred_name, x_mat, invx_mat, sample_mat], node_embeds, batch_mode=True)

            rank = torch.sum(tail_score >= true_score).item() + 1
            rrank += 1.0 / rank
            hits += 1 if rank <= 10 else 0
            rrank_pred[pred_name] += 1.0 / rank
            hits_pred[pred_name] += 1 if rank <= 10 else 0

            rank = torch.sum(head_score >= true_score).item() + 1
            rrank += 1.0 / rank
            hits += 1 if rank <= 10 else 0
            rrank_pred[pred_name] += 1.0 / rank
            hits_pred[pred_name] += 1 if rank <= 10 else 0

            cnt_pred[pred_name] += 2
            cnt += 2

            if cnt % 100 == 0:
                with open(logpath, 'w') as f:
                    f.write('%i sample eval\n' % cnt)
                    f.write('mmr %.4f\n' % (rrank / cnt))
                    f.write('hits %.4f\n' % (hits / cnt))
                    f.write('\n')
                    for pred_name in PRED_DICT:
                        if cnt_pred[pred_name] == 0:
                            continue
                        f.write('mmr %s %.4f\n' % (pred_name, rrank_pred[pred_name] / cnt_pred[pred_name]))
                        f.write('hits %s %.4f\n' % (pred_name, hits_pred[pred_name] / cnt_pred[pred_name]))

            pbar.update()

        with open(logpath, 'w') as f:
            f.write('complete\n')
            f.write('mmr %.4f\n' % (rrank / cnt))
            f.write('hits %.4f\n' % (hits / cnt))
            f.write('\n')
            tqdm.write('mmr %.4f\n' % (rrank / cnt))
            tqdm.write('hits %.4f\n' % (hits / cnt))
            for pred_name in PRED_DICT:
                if cnt_pred[pred_name] == 0:
                    continue
                f.write('mmr %s %.4f\n' % (pred_name, rrank_pred[pred_name] / cnt_pred[pred_name]))
                f.write('hits %s %.4f\n' % (pred_name, hits_pred[pred_name] / cnt_pred[pred_name]))

        os.system('mv %s %s' % (logpath, joinpath(
            cmd_args.exp_path,
            'performance_hits_%.4f_mmr_%.4f.txt' % ((hits / cnt), (rrank / cnt)))))

        pbar.close()

    # for Kinship / UW-CSE / Cora data
    elif cmd_args.load_method == 0:

        for current_epoch in range(cmd_args.num_epochs):
            pbar = tqdm(range(cmd_args.num_batches))
            acc_loss = 0.0

            for k in pbar:
                node_embeds = gcn(dataset)

                batch_neg_mask, flat_list, batch_latent_var_inds, observed_rule_cnts, batch_observed_vars = \
                    dataset.get_batch_rnd(observed_prob=cmd_args.observed_prob,
                                          filter_latent=cmd_args.filter_latent == 1,
                                          closed_world=cmd_args.closed_world == 1,
                                          filter_observed=1)

                posterior_prob = posterior_model(flat_list, node_embeds)

                if cmd_args.no_entropy == 1:
                    entropy = 0
                else:
                    entropy = compute_entropy(posterior_prob) / cmd_args.entropy_temp
                    entropy = entropy.to('cpu')

                posterior_prob = posterior_prob.to('cpu')

                potential = mln(batch_neg_mask, batch_latent_var_inds, observed_rule_cnts,
                                posterior_prob, flat_list, batch_observed_vars)

                optimizer.zero_grad()
                loss = -(potential + entropy) / cmd_args.batchsize
                acc_loss += loss.item()
                loss.backward()
                optimizer.step()

                pbar.set_description('train loss: %.4f, lr: %.4g'
                                     % (acc_loss / (k + 1), get_lr(optimizer)))

            # test
            node_embeds = gcn(dataset)
            with torch.no_grad():
                posterior_prob = posterior_model(
                    [(e[1], e[2]) for e in dataset.test_fact_ls], node_embeds)
                posterior_prob = posterior_prob.to('cpu')

                label = np.array([e[0] for e in dataset.test_fact_ls])
                test_log_prob = float(np.sum(np.log(np.clip(
                    np.abs((1 - label) - posterior_prob.numpy()), 1e-6, 1 - 1e-6))))

                auc_roc = roc_auc_score(label, posterior_prob.numpy())
                auc_pr = average_precision_score(label, posterior_prob.numpy())

            tqdm.write('Epoch: %d, train loss: %.4f, test auc-roc: %.4f, test auc-pr: %.4f, test log prob: %.4f'
                       % (current_epoch, acc_loss / cmd_args.num_batches, auc_roc, auc_pr, test_log_prob))
            # tqdm.write(str(posterior_prob[:10]))

            # validation for early stop
            valid_sample = []
            valid_label = []
            for pred_name in dataset.valid_dict_2:
                for val, consts in dataset.valid_dict_2[pred_name]:
                    valid_sample.append((pred_name, consts))
                    valid_label.append(val)
            valid_label = np.array(valid_label)

            valid_prob = posterior_model(valid_sample, node_embeds)
            valid_prob = valid_prob.to('cpu')
            valid_log_prob = float(np.sum(np.log(np.clip(
                np.abs((1 - valid_label) - valid_prob.numpy()), 1e-6, 1 - 1e-6))))

            # tqdm.write('epoch: %d, valid log prob: %.4f' % (current_epoch, valid_log_prob))
            #
            # should_stop = monitor.update(-valid_log_prob)
            # scheduler.step(valid_log_prob)
            #
            # is_current_best = monitor.cnt == 0
            # if is_current_best:
            #     savepath = joinpath(cmd_args.exp_path, 'saved_model')
            #     os.makedirs(savepath, exist_ok=True)
            #     torch.save(gcn.state_dict(), joinpath(savepath, 'gcn.model'))
            #     torch.save(posterior_model.state_dict(), joinpath(savepath, 'posterior.model'))
            #
            # should_stop = should_stop or (current_epoch + 1 == cmd_args.num_epochs)
            #
            # if should_stop:
            #     tqdm.write('Early stopping')
            #     break

        # evaluation after training
        node_embeds = gcn(dataset)
        with torch.no_grad():
            posterior_prob = posterior_model(
                [(e[1], e[2]) for e in dataset.test_fact_ls], node_embeds)
            posterior_prob = posterior_prob.to('cpu')

            label = np.array([e[0] for e in dataset.test_fact_ls])
            test_log_prob = float(np.sum(np.log(np.clip(
                np.abs((1 - label) - posterior_prob.numpy()), 1e-6, 1 - 1e-6))))

            auc_roc = roc_auc_score(label, posterior_prob.numpy())
            auc_pr = average_precision_score(label, posterior_prob.numpy())

        tqdm.write('test auc-roc: %.4f, test auc-pr: %.4f, test log prob: %.4f'
                   % (auc_roc, auc_pr, test_log_prob))

def log_save(self, **kwargs):
    self.step()

    # don't log anything if running on the aicrowd_server
    if self.on_aicrowd_server:
        return

    # save a checkpoint every ckpt_save_iter
    if is_time_for(self.iter, self.ckpt_save_iter):
        self.save_checkpoint()

    if is_time_for(self.iter, self.print_iter):
        msg = '[{}:{}] '.format(self.epoch, self.iter)
        for key, value in kwargs.get(c.LOSS, dict()).items():
            msg += '{}_{}={:.3f} '.format(c.LOSS, key, value)
        for key, value in kwargs.get(c.ACCURACY, dict()).items():
            msg += '{}_{}={:.3f} '.format(c.ACCURACY, key, value)
        self.pbar.write(msg)

    # visualize the reconstruction of the current batch every recon_iter
    if is_time_for(self.iter, self.recon_iter):
        self.visualize_recon(kwargs[c.INPUT_IMAGE], kwargs[c.RECON_IMAGE])

    # traverse the latent factors every traverse_iter
    if is_time_for(self.iter, self.traverse_iter):
        self.visualize_traverse(limit=(self.traverse_min, self.traverse_max),
                                spacing=self.traverse_spacing)

    # if any evaluation is included in args.evaluate_metric, evaluate every evaluate_iter
    if self.evaluation_metric and is_time_for(self.iter, self.evaluate_iter):
        self.evaluate_results = evaluate_disentanglement_metric(
            self, metric_names=self.evaluation_metric)

    # log scalar values using wandb
    if is_time_for(self.iter, self.float_iter):
        # average results
        for key, value in self.info_cumulative.items():
            self.info_cumulative[key] /= self.float_iter

        # other values to log
        self.info_cumulative[c.ITERATION] = self.iter
        self.info_cumulative[c.LEARNING_RATE] = get_lr(self.optim_dict['optim_G'])  # assuming we want optim_G

        # todo: not happy with this architecture for logging... should make it easier to add new variables to log
        if self.evaluation_metric:
            for key, value in self.evaluate_results.items():
                self.info_cumulative[key] = value

        if self.use_wandb:
            import wandb
            wandb.log(self.info_cumulative, step=self.iter)

        # empty info_cumulative
        for key, value in self.info_cumulative.items():
            self.info_cumulative[key] = 0
    else:
        # accumulate results
        for key, value in kwargs.items():
            if isinstance(value, float):
                self.info_cumulative[key] = value + self.info_cumulative.get(key, 0)
            elif isinstance(value, dict):
                for subkey, subvalue in value.items():
                    complex_key = key + '/' + subkey
                    self.info_cumulative[complex_key] = float(subvalue) + self.info_cumulative.get(complex_key, 0)

    # update schedulers
    if is_time_for(self.iter, self.schedulers_iter):
        self.schedulers_step(kwargs.get(c.LOSS, dict()).get(c.TOTAL_VAE_EPOCH, 0),
                             self.iter // self.schedulers_iter)

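# NOTE: `is_time_for` is not defined in this snippet. A minimal sketch, assuming it
# checks whether the current iteration falls on the given interval (an assumption about
# the helper's behaviour, not the original implementation):
def is_time_for(iteration, interval):
    """Return True when `iteration` is a positive multiple of `interval`."""
    return interval > 0 and iteration % interval == 0
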
def train(net, opt, train_data, val_data, num_train_classes, context, run_id):
    """Training function"""
    if not opt.skip_pretrain_validation:
        validation_results = validate(net, val_data, context, binarize=opt.binarize,
                                      nmi=opt.nmi, similarity=opt.similarity)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    # Calculate decay steps
    steps = parse_steps(opt.steps, opt.epochs, logger=logging)

    # Init optimizer
    opt_options = {'learning_rate': opt.lr, 'wd': opt.wd, 'clip_gradient': 10.}
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    elif opt.optimizer == 'adam':
        opt_options['epsilon'] = opt.epsilon
    elif opt.optimizer == 'rmsprop':
        opt_options['gamma1'] = 0.9
        opt_options['epsilon'] = opt.epsilon

    # We train only embedding and proxies initially
    params2train = net.encoder.collect_params()
    if not opt.static_proxies:
        params2train.update(net.proxies.collect_params())

    trainer = mx.gluon.Trainer(params2train, opt.optimizer, opt_options,
                               kvstore=opt.kvstore)

    smoothing_constant = .01  # for tracking moving losses
    moving_loss = 0
    best_results = []  # R@1, NMI
    batch_size = opt.batch_size * len(context)

    proxyloss = ProxyXentropyLoss(num_train_classes, label_smooth=opt.label_smooth,
                                  temperature=opt.temperature)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        if epoch == 2:
            # switch training to all parameters
            logging.info('Switching to train all parameters')
            trainer = mx.gluon.Trainer(net.collect_params(), opt.optimizer, opt_options,
                                       kvstore=opt.kvstore)

        if opt.batch_k > 0:
            iterations_per_epoch = int(ceil(train_data.num_training_images() / batch_size))
            p_bar = tqdm(range(iterations_per_epoch),
                         desc='[Run %d/%d] Epoch %d' % (run_id, opt.number_of_runs, epoch),
                         total=iterations_per_epoch)
        else:
            p_bar = tqdm(enumerate(train_data), total=len(train_data),
                         desc=('[Run %d/%d] Epoch %d' % (run_id, opt.number_of_runs, epoch)))

        new_lr = get_lr(opt.lr, epoch, steps, opt.factor)
        logging.info('Setting LR to %f' % new_lr)
        trainer.set_learning_rate(new_lr)

        if opt.optimizer == 'rmsprop':
            # exponential decay of gamma
            if epoch != 1:
                trainer._optimizer.gamma1 *= .94
                logging.info('Setting rmsprop gamma to %f' % trainer._optimizer.gamma1)

        losses = []
        curr_losses_np = []
        for i in p_bar:
            if opt.batch_k > 0:
                num_sampled_classes = batch_size // opt.batch_k
                batch = train_data.next_proxy_sample(sampled_classes=num_sampled_classes,
                                                     chose_classes_randomly=True).data
            else:
                batch = i[1]
                i = i[0]

            data = mx.gluon.utils.split_and_load(batch[0], ctx_list=context, batch_axis=0, even_split=False)
            label = mx.gluon.utils.split_and_load(batch[1], ctx_list=context, batch_axis=0, even_split=False)

            with ag.record():
                for x, y in zip(data, label):
                    embs, proxies = net(x)
                    curr_loss = proxyloss(embs, proxies, y)
                    losses.append(curr_loss)
            mx.nd.waitall()
            curr_losses_np += [cl.asnumpy() for cl in losses]
            ag.backward(losses)

            trainer.step(batch[0].shape[0])

            # Keep a moving average of the losses
            curr_loss = np.mean(np.maximum(np.concatenate(curr_losses_np), 0))
            curr_losses_np.clear()
            losses.clear()
            moving_loss = (curr_loss if ((i == 0) and (epoch == 1))  # starting value
                           else (1 - smoothing_constant) * moving_loss
                           + smoothing_constant * curr_loss)
            p_bar.set_postfix_str('Moving loss: %.4f' % moving_loss)

        logging.info('Moving loss: %.4f' % moving_loss)

        validation_results = validate(net, val_data, context, binarize=opt.binarize,
                                      nmi=opt.nmi, similarity=opt.similarity)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' % (epoch, name, val_acc))

        if (len(best_results) == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            filename = '%s.params' % opt.save_model_prefix
            logging.info('Saving %s.' % filename)
            net.save_parameters(filename)
            logging.info('New best validation: R@1: %f%s'
                         % (best_results[0][1],
                            (' NMI: %f' % best_results[-1][1]) if opt.nmi else ''))

    return best_results

def train_dreml(opt):
    logging.info(opt)

    # Set random seed
    mx.random.seed(opt.seed)
    np.random.seed(opt.seed)

    # Setup computation context
    context = get_context(opt.gpus, logging)
    cpu_ctx = mx.cpu()

    # Adjust batch size to each compute context
    batch_size = opt.batch_size * len(context)

    if opt.model == 'inception-bn':
        scale_image_data = False
    elif opt.model in ['resnet50_v2', 'resnet18_v2']:
        scale_image_data = True
    else:
        raise RuntimeError('Unsupported model: %s' % opt.model)

    # Prepare datasets
    train_dataset, val_dataset = get_dataset(opt.dataset, opt.data_path,
                                             data_shape=opt.data_shape,
                                             use_crops=opt.use_crops, use_aug=True,
                                             with_proxy=True,
                                             scale_image_data=scale_image_data,
                                             resize_img=int(opt.data_shape * 1.1))

    # Create class mapping
    mapping = np.random.randint(0, opt.D, (opt.L, train_dataset.num_classes()))

    # Train embedding functions one by one
    trained_models = []
    best_results = []  # R@1, NMI
    for ens in tqdm(range(opt.L), desc='Training model in ensemble'):
        train_dataset.set_class_mapping(mapping[ens], opt.D)
        train_dataloader = mx.gluon.data.DataLoader(train_dataset, batch_size=batch_size,
                                                    shuffle=True, num_workers=opt.num_workers,
                                                    last_batch='rollover')

        if opt.model == 'inception-bn':
            feature_net, feature_params = get_feature_model(opt.model, ctx=context)
        elif opt.model == 'resnet50_v2':
            feature_net = mx.gluon.model_zoo.vision.resnet50_v2(pretrained=True, ctx=context).features
        elif opt.model == 'resnet18_v2':
            feature_net = mx.gluon.model_zoo.vision.resnet18_v2(pretrained=True, ctx=context).features
        else:
            raise RuntimeError('Unsupported model: %s' % opt.model)

        if opt.static_proxies:
            net = EmbeddingNet(feature_net, opt.D, normalize=False)
        else:
            net = ProxyNet(feature_net, opt.D, num_classes=opt.D)

        # Init loss function
        if opt.static_proxies:
            logging.info('Using static proxies')
            proxyloss = StaticProxyLoss(opt.D)
        elif opt.loss == 'nca':
            logging.info('Using NCA loss')
            proxyloss = ProxyNCALoss(opt.D, exclude_positives=True,
                                     label_smooth=opt.label_smooth,
                                     multiplier=opt.embedding_multiplier)
        elif opt.loss == 'triplet':
            logging.info('Using triplet loss')
            proxyloss = ProxyTripletLoss(opt.D)
        elif opt.loss == 'xentropy':
            logging.info('Using NCA loss without excluding positives')
            proxyloss = ProxyNCALoss(opt.D, exclude_positives=False,
                                     label_smooth=opt.label_smooth,
                                     multiplier=opt.embedding_multiplier)
        else:
            raise RuntimeError('Unknown loss function: %s' % opt.loss)

        # Init optimizer
        opt_options = {'learning_rate': opt.lr, 'wd': opt.wd}
        if opt.optimizer == 'sgd':
            opt_options['momentum'] = 0.9
        elif opt.optimizer == 'adam':
            opt_options['epsilon'] = opt.epsilon
        elif opt.optimizer == 'rmsprop':
            opt_options['gamma1'] = 0.9
            opt_options['epsilon'] = opt.epsilon

        # Calculate decay steps
        steps = parse_steps(opt.steps, opt.epochs, logger=logging)

        # reset networks
        if opt.model == 'inception-bn':
            net.base_net.collect_params().load(feature_params, ctx=context, ignore_extra=True)
        elif opt.model in ['resnet18_v2', 'resnet50_v2']:
            net.base_net = mx.gluon.model_zoo.vision.get_model(opt.model, pretrained=True, ctx=context).features
        else:
            raise NotImplementedError('Unknown model: %s' % opt.model)

        if opt.static_proxies:
            net.init(mx.init.Xavier(magnitude=0.2), ctx=context, init_basenet=False)
        elif opt.loss == 'triplet':
            net.encoder.initialize(mx.init.Xavier(magnitude=0.2), ctx=context, force_reinit=True)
            net.proxies.initialize(mx.init.Xavier(magnitude=0.2), ctx=context, force_reinit=True)
        else:
            net.init(TruncNorm(stdev=0.001), ctx=context, init_basenet=False)

        if not opt.disable_hybridize:
            net.hybridize()

        trainer = mx.gluon.Trainer(net.collect_params(), opt.optimizer, opt_options,
                                   kvstore=opt.kvstore)

        smoothing_constant = .01  # for tracking moving losses
        moving_loss = 0

        for epoch in range(1, opt.epochs + 1):
            p_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader),
                         desc=('[Model %d/%d] Epoch %d' % (ens + 1, opt.L, epoch)))

            new_lr = get_lr(opt.lr, epoch, steps, opt.factor)
            logging.info('Setting LR to %f' % new_lr)
            trainer.set_learning_rate(new_lr)

            for i, batch in p_bar:
                data = mx.gluon.utils.split_and_load(batch[0], ctx_list=context, batch_axis=0, even_split=False)
                label = mx.gluon.utils.split_and_load(batch[1], ctx_list=context, batch_axis=0, even_split=False)
                negative_labels = mx.gluon.utils.split_and_load(batch[2], ctx_list=context, batch_axis=0, even_split=False)

                with ag.record():
                    losses = []
                    for x, y, nl in zip(data, label, negative_labels):
                        if opt.static_proxies:
                            embs = net(x)
                            losses.append(proxyloss(embs, y))
                        else:
                            embs, positive_proxy, negative_proxies, proxies = net(x, y, nl)
                            if opt.loss in ['nca', 'xentropy']:
                                losses.append(proxyloss(embs, proxies, y, nl))
                            else:
                                losses.append(proxyloss(embs, positive_proxy, negative_proxies))
                for l in losses:
                    l.backward()

                trainer.step(data[0].shape[0])

                ##########################
                # Keep a moving average of the losses
                ##########################
                curr_loss = mx.nd.mean(mx.nd.maximum(mx.nd.concatenate(losses), 0)).asscalar()
                moving_loss = (curr_loss if ((i == 0) and (epoch == 1))  # starting value
                               else (1 - smoothing_constant) * moving_loss
                               + smoothing_constant * curr_loss)
                p_bar.set_postfix_str('Moving loss: %.4f' % moving_loss)

            logging.info('Moving loss: %.4f' % moving_loss)

        # move model to CPU
        mx.nd.waitall()
        net.collect_params().reset_ctx(cpu_ctx)
        trained_models.append(net)
        del train_dataloader

    # Run ensemble validation
    logging.info('Running validation with %d models in the ensemble' % len(trained_models))
    val_dataloader = mx.gluon.data.DataLoader(val_dataset, batch_size=batch_size,
                                              shuffle=False, num_workers=opt.num_workers,
                                              last_batch='keep')
    validation_results = validate(val_dataloader, trained_models, context, opt.static_proxies)
    for name, val_acc in validation_results:
        logging.info('Validation: %s=%f' % (name, val_acc))

    if (len(best_results) == 0) or (validation_results[0][1] > best_results[0][1]):
        best_results = validation_results
        logging.info('New best validation: R@1: %f NMI: %f'
                     % (best_results[0][1], best_results[-1][1]))

def train(net, opt, train_dataloader, val_dataloader, context, run_id):
    """Training function."""
    if not opt.skip_pretrain_validation:
        validation_results = validate(net, val_dataloader, context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    steps = parse_steps(opt.steps, opt.epochs, logging)

    opt_options = {'learning_rate': opt.lr, 'wd': opt.wd, 'clip_gradient': 10.}
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    if opt.optimizer == 'adam':
        opt_options['epsilon'] = 1e-7

    if opt.decrease_cnn_lr:
        logging.info('Setting embedding LR to %f' % (10.0 * opt.lr))
        for p, v in net.encoder.collect_params().items():
            v.lr_mult = 10.0

    trainer = mx.gluon.Trainer(net.collect_params(), opt.optimizer, opt_options,
                               kvstore=opt.kvstore)

    if opt.angular_lambda > 0:
        # Use NPair and Angular loss together; l2 regularization is 0 for angular in this case
        L = AngluarLoss(alpha=np.deg2rad(opt.alpha), l2_reg=0, symmetric=opt.symmetric_loss)
        L2 = NPairsLoss(l2_reg=opt.l2reg_weight, symmetric=opt.symmetric_loss)
        if not opt.disable_hybridize:
            L2.hybridize()
    else:
        L = AngluarLoss(alpha=np.deg2rad(opt.alpha), l2_reg=opt.l2reg_weight,
                        symmetric=opt.symmetric_loss)
    if not opt.disable_hybridize:
        L.hybridize()

    best_results = []  # R@1, NMI

    for epoch in range(1, opt.epochs + 1):
        prev_loss, cumulative_loss = 0.0, 0.0

        # Learning rate schedule.
        trainer.set_learning_rate(get_lr(opt.lr, epoch, steps, opt.factor))
        logging.info('Epoch %d learning rate=%f', epoch, trainer.learning_rate)

        p_bar = tqdm(train_dataloader,
                     desc=('[Run %d/%d] Epoch %d' % (run_id, opt.number_of_runs, epoch)))
        for batch in p_bar:
            anchors_batch = batch[0][0]    # <N x I>
            positives_batch = batch[1][0]  # <N x I>

            anchors = mx.gluon.utils.split_and_load(anchors_batch, ctx_list=context, batch_axis=0)
            positives = mx.gluon.utils.split_and_load(positives_batch, ctx_list=context, batch_axis=0)
            labels_batch = mx.gluon.utils.split_and_load(batch[2][0], ctx_list=context, batch_axis=0)

            anchor_embs = []
            positives_embs = []
            with ag.record():
                for a, p in zip(anchors, positives):
                    a_emb = net(a)
                    p_emb = net(p)
                    anchor_embs.append(a_emb)
                    positives_embs.append(p_emb)
                anchors = mx.nd.concat(*anchor_embs, dim=0)
                positives = mx.nd.concat(*positives_embs, dim=0)

                if opt.angular_lambda > 0:
                    angular_loss = L(anchors, positives, labels_batch[0])
                    npairs_loss = L2(anchors, positives, labels_batch[0])
                    loss = npairs_loss + (opt.angular_lambda * angular_loss)
                else:
                    loss = L(anchors, positives, labels_batch[0])
            loss.backward()
            cumulative_loss += mx.nd.mean(loss).asscalar()

            trainer.step(opt.batch_size)

            p_bar.set_postfix({'loss': cumulative_loss - prev_loss})
            prev_loss = cumulative_loss

        validation_results = validate(net, val_dataloader, context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' % (epoch, name, val_acc))

        if (len(best_results) == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            if opt.save_model_prefix.lower() != 'none':
                filename = '%s.params' % opt.save_model_prefix
                logging.info('Saving %s.' % filename)
                net.save_parameters(filename)
            logging.info('New best validation: R@1: %f NMI: %f'
                         % (best_results[0][1], best_results[-1][1]))

    return best_results

def train(net, opt, train_dataloader, val_dataloader, context, run_id):
    """Training function."""
    if not opt.skip_pretrain_validation:
        validation_results = validate(net, val_dataloader, context,
                                      use_threads=opt.num_workers > 0, nmi=opt.nmi)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    steps = parse_steps(opt.steps, opt.epochs, logging)

    opt_options = {'learning_rate': opt.lr, 'wd': opt.wd}
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    if opt.optimizer == 'adam':
        opt_options['epsilon'] = 1e-7

    trainer = mx.gluon.Trainer(net.collect_params(), opt.optimizer, opt_options,
                               kvstore=opt.kvstore)

    L = RankedListLoss(margin=opt.margin, alpha=opt.alpha, temperature=opt.temperature)
    if not opt.disable_hybridize:
        L.hybridize()

    smoothing_constant = .01  # for tracking moving losses
    moving_loss = 0
    best_results = []  # R@1, NMI

    for epoch in range(1, opt.epochs + 1):
        p_bar = tqdm(enumerate(train_dataloader),
                     desc='[Run %d/%d] Epoch %d' % (run_id, opt.number_of_runs, epoch),
                     total=len(train_dataloader))

        trainer.set_learning_rate(get_lr(opt.lr, epoch, steps, opt.factor))

        for i, (data, labels) in p_bar:
            data = data[0].as_in_context(context[0])
            labels = labels[0].astype('int32').as_in_context(context[0])

            with ag.record():
                losses = []
                embs = net(data)
                losses.append(L(embs, labels))
            for l in losses:
                l.backward()

            trainer.step(1)

            # Keep a moving average of the losses
            curr_loss = mx.nd.mean(mx.nd.concatenate(losses)).asscalar()
            moving_loss = (curr_loss if ((i == 0) and (epoch == 1))  # starting value
                           else (1 - smoothing_constant) * moving_loss
                           + smoothing_constant * curr_loss)  # add current
            p_bar.set_postfix_str('Moving loss: %.4f' % moving_loss)

        logging.info('Moving loss: %.4f' % moving_loss)

        validation_results = validate(net, val_dataloader, context,
                                      use_threads=opt.num_workers > 0, nmi=opt.nmi)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' % (epoch, name, val_acc))

        if (len(best_results) == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            if opt.save_model_prefix.lower() != 'none':
                filename = '%s.params' % opt.save_model_prefix
                logging.info('Saving %s.' % filename)
                net.save_parameters(filename)
            logging.info('New best validation: R@1: %f NMI: %f'
                         % (best_results[0][1], best_results[-1][1]))

    return best_results

def train(net, opt, train_dataloader, val_dataloader, context, run_id):
    """Training function."""
    if not opt.skip_pretrain_validation:
        validation_results = validate(net, val_dataloader, context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    steps = parse_steps(opt.steps, opt.epochs, logging)

    opt_options = {'learning_rate': opt.lr, 'wd': opt.wd}
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    elif opt.optimizer == 'adam':
        opt_options['epsilon'] = opt.epsilon
    elif opt.optimizer == 'rmsprop':
        opt_options['gamma1'] = 0.94
        opt_options['epsilon'] = opt.epsilon

    if opt.decrease_cnn_lr:
        logging.info('Setting embedding LR to %f' % (10.0 * opt.lr))
        for p, v in net.encoder.collect_params().items():
            v.lr_mult = 10.0

    trainer = mx.gluon.Trainer(net.collect_params(), opt.optimizer, opt_options,
                               kvstore=opt.kvstore)

    L = ClusterLoss(num_classes=train_dataloader._dataset.num_classes())  # Not hybridizable

    smoothing_constant = .01  # for tracking moving losses
    moving_loss = 0
    best_results = []  # R@1, NMI

    for epoch in range(1, opt.epochs + 1):
        p_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader),
                     desc=('[Run %d/%d] Epoch %d' % (run_id, opt.number_of_runs, epoch)))

        trainer.set_learning_rate(get_lr(opt.lr, epoch, steps, opt.factor))
        if opt.optimizer == 'rmsprop':
            # exponential decay of gamma
            if epoch != 1:
                trainer._optimizer.gamma1 *= .94

        for i, (data, labels) in p_bar:
            if opt.iteration_per_epoch > 0:
                data = data[0]
                labels = labels[0]

            labels = labels.astype('int32', copy=False)
            unique_labels = unique(mx.nd, labels).astype('float32')

            # extract label stats
            num_classes_batch = []
            if len(context) == 1:
                num_classes_batch.append(mx.nd.array([unique_labels.size], dtype='int32'))
            else:
                slices = mx.gluon.utils.split_data(labels, len(context), batch_axis=0,
                                                   even_split=False)
                for s in slices:
                    num_classes_batch.append(
                        mx.nd.array([np.unique(s.asnumpy()).size], dtype='int32'))

            data = mx.gluon.utils.split_and_load(data, ctx_list=context, batch_axis=0, even_split=False)
            label = mx.gluon.utils.split_and_load(labels, ctx_list=context, batch_axis=0, even_split=False)
            unique_labels = mx.gluon.utils.split_and_load(unique_labels, ctx_list=context, batch_axis=0, even_split=False)

            with ag.record():
                losses = []
                for x, y, uy, nc in zip(data, label, unique_labels, num_classes_batch):
                    embs = net(x)
                    losses.append(L(embs, y.astype('float32', copy=False), uy,
                                    mx.nd.arange(start=0, stop=x.shape[0], ctx=y.context)))
            for l in losses:
                l.backward()

            trainer.step(data[0].shape[0])

            # Keep a moving average of the losses
            curr_loss = mx.nd.mean(mx.nd.concatenate(losses)).asscalar()
            moving_loss = (curr_loss if ((i == 0) and (epoch == 1))  # starting value
                           else (1 - smoothing_constant) * moving_loss
                           + smoothing_constant * curr_loss)
            p_bar.set_postfix_str('Moving loss: %.4f' % moving_loss)

        logging.info('Moving loss: %.4f' % moving_loss)

        validation_results = validate(net, val_dataloader, context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' % (epoch, name, val_acc))

        if (len(best_results) == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            if opt.save_model_prefix.lower() != 'none':
                filename = '%s.params' % opt.save_model_prefix
                logging.info('Saving %s.' % filename)
                net.save_parameters(filename)
            logging.info('New best validation: R@1: %f NMI: %f'
                         % (best_results[0][1], best_results[-1][1]))

    return best_results

def train(net, opt, train_dataloader, val_dataloader, context, run_id):
    """Training function."""
    if not opt.skip_pretrain_validation:
        validation_results = validate(net, val_dataloader, context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    steps = parse_steps(opt.steps, opt.epochs, logging)

    opt_options = {'learning_rate': opt.lr, 'wd': opt.wd}
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    if opt.optimizer == 'adam':
        opt_options['epsilon'] = 1e-7

    if opt.decrease_cnn_lr:
        logging.info('Setting embedding LR to %f' % (10.0 * opt.lr))
        for p, v in net.encoder.collect_params().items():
            v.lr_mult = 10.0

    trainer = mx.gluon.Trainer(net.collect_params(), opt.optimizer, opt_options,
                               kvstore=opt.kvstore)

    L = NPairsLoss(l2_reg=opt.l2reg_weight, symmetric=opt.symmetric_loss)
    if not opt.disable_hybridize:
        L.hybridize()

    smoothing_constant = .01  # for tracking moving losses
    moving_loss = 0
    best_results = []  # R@1, NMI

    for epoch in range(1, opt.epochs + 1):
        p_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader),
                     desc=('[Run %d/%d] Epoch %d' % (run_id, opt.number_of_runs, epoch)))

        trainer.set_learning_rate(get_lr(opt.lr, epoch, steps, opt.factor))

        for i, batch in p_bar:
            anchors_batch = batch[0][0]    # <N x I>
            positives_batch = batch[1][0]  # <N x I>

            anchors = mx.gluon.utils.split_and_load(anchors_batch, ctx_list=context, batch_axis=0)
            positives = mx.gluon.utils.split_and_load(positives_batch, ctx_list=context, batch_axis=0)
            labels_batch = mx.gluon.utils.split_and_load(batch[2][0], ctx_list=context, batch_axis=0)

            anchor_embs = []
            positives_embs = []
            with ag.record():
                for a, p in zip(anchors, positives):
                    a_emb = net(a)
                    p_emb = net(p)
                    anchor_embs.append(a_emb)
                    positives_embs.append(p_emb)
                anchors = mx.nd.concat(*anchor_embs, dim=0)
                positives = mx.nd.concat(*positives_embs, dim=0)

                loss = L(anchors, positives, labels_batch[0])
            loss.backward()

            trainer.step(opt.batch_size / 2)

            curr_loss = mx.nd.mean(loss).asscalar()
            moving_loss = (curr_loss if ((i == 0) and (epoch == 1))  # starting value
                           else (1 - smoothing_constant) * moving_loss
                           + smoothing_constant * curr_loss)  # add current
            p_bar.set_postfix_str('Moving loss: %.4f' % moving_loss)

        validation_results = validate(net, val_dataloader, context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' % (epoch, name, val_acc))

        if (len(best_results) == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            if opt.save_model_prefix.lower() != 'none':
                filename = '%s.params' % opt.save_model_prefix
                logging.info('Saving %s.' % filename)
                net.save_parameters(filename)
            logging.info('New best validation: R@1: %f NMI: %f'
                         % (best_results[0][1], best_results[-1][1]))

    return best_results

def train(net, beta, opt, train_dataloader, val_dataloader, batch_size, context, run_id):
    """Training function."""
    if not opt.skip_pretrain_validation:
        validation_results = validate(net, val_dataloader, context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    steps = parse_steps(opt.steps, opt.epochs, logging)

    opt_options = {'learning_rate': opt.lr, 'wd': opt.wd}
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    if opt.optimizer == 'adam':
        opt_options['epsilon'] = 1e-7

    trainer = gluon.Trainer(net.collect_params(), opt.optimizer, opt_options,
                            kvstore=opt.kvstore)

    train_beta = not isinstance(beta, float)
    if train_beta:
        # Jointly train class-specific beta
        beta.initialize(mx.init.Constant(opt.beta), ctx=context)
        trainer_beta = gluon.Trainer(beta.collect_params(), 'sgd',
                                     {'learning_rate': opt.lr_beta, 'momentum': 0.9},
                                     kvstore=opt.kvstore)

    loss = MarginLoss(batch_size, opt.batch_k, beta, margin=opt.margin, nu=opt.nu,
                      train_beta=train_beta)
    if not opt.disable_hybridize:
        loss.hybridize()

    best_results = []  # R@1, NMI

    for epoch in range(1, opt.epochs + 1):
        prev_loss, cumulative_loss = 0.0, 0.0

        # Learning rate schedule.
        trainer.set_learning_rate(get_lr(opt.lr, epoch, steps, opt.factor))
        logging.info('Epoch %d learning rate=%f', epoch, trainer.learning_rate)
        if train_beta:
            trainer_beta.set_learning_rate(get_lr(opt.lr_beta, epoch, steps, opt.factor))
            logging.info('Epoch %d beta learning rate=%f', epoch, trainer_beta.learning_rate)

        p_bar = tqdm(train_dataloader,
                     desc='[Run %d/%d] Epoch %d' % (run_id, opt.number_of_runs, epoch),
                     total=opt.iteration_per_epoch)
        for batch in p_bar:
            data = gluon.utils.split_and_load(batch[0][0], ctx_list=context, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1][0].astype('float32'),
                                               ctx_list=context, batch_axis=0)

            Ls = []
            with ag.record():
                for x, y in zip(data, label):
                    embeddings = net(x)
                    L = loss(embeddings, y)
                    Ls.append(L)
                    cumulative_loss += nd.mean(L).asscalar()
            for L in Ls:
                L.backward()

            trainer.step(batch[0].shape[1])
            if opt.lr_beta > 0.0:
                trainer_beta.step(batch[0].shape[1])

            p_bar.set_postfix({'loss': cumulative_loss - prev_loss})
            prev_loss = cumulative_loss

        logging.info('[Epoch %d] training loss=%f' % (epoch, cumulative_loss))

        validation_results = validate(net, val_dataloader, context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' % (epoch, name, val_acc))

        if (len(best_results) == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            if opt.save_model_prefix.lower() != 'none':
                filename = '%s.params' % opt.save_model_prefix
                logging.info('Saving %s.' % filename)
                net.save_parameters(filename)
            logging.info('New best validation: R@1: %f NMI: %f'
                         % (best_results[0][1], best_results[-1][1]))

    return best_results