本章主要描述如何在RecStudio进行新模型的开发,如何像RecStudio开发者一样编排模型并通过 quickstart 运行。

(目前还没看过知识图谱、社交图谱的那些模型和文章,这个需要先看哪个模型然后才能写呀?或许我先看一下?)

(模型的运行逻辑)

对于retriever模型

模型类的填写

为了构建自己的模型,你需要完成自己的mymodel.py文件,在其中声明一个你的模型的类,在RecStudio中运行
Retriever类模型,你可以进行重载的常用的几个函数如下(函数内容为例子)

class MYMODEL(basemodel.BaseRetriever):
    def add_model_specific_args(parent_parser):
        # 你可以使用该参数添加自己的模型中需要的额外参数,格式如下
        parent_parser = basemodel.Recommender.add_model_specific_args(parent_parser)
        parent_parser.add_argument_group('MYMODEL')
        parent_parser.add_argument("--negative_count", type=int, default=1, help='negative sampling numbers')
        return parent_parser

    def _set_data_field(self, data):
        # 这个函数用于加载你要具体使用data的哪些特征进行训练,你可以通过访问data.field来获取已加载的特征
        data.use_field = data.field

    def _get_dataset_class():
        # 用于返回数据集的种类,默认MFDataset
        return MFDataset

    def _get_query_encoder(self, train_data):
        # 用于返回query的encoder,默认为Embedding
        return torch.nn.Embedding(train_data.num_users, self.embed_dim, padding_idx=0)

    def _get_item_encoder(self, train_data):
        # 用于返回item的encoder,默认为Embedding
       return torch.nn.Embedding(train_data.num_items, self.embed_dim, padding_idx=0)

    def _get_score_func(self):
        # 返回打分函数
        return scorer.InnerProductScorer()

    def _get_loss_func(self):
        # 返回损失函数
        return loss_func.BinaryCrossEntropyLoss

    def _get_sampler(self, train_data):
        # 返回采样器
        return sampler.UniformSampler(train_data.num_items)

    def _get_optimizers(self):
        # 返回优化器
        return None

    def _get_train_loaders(self, train_data, ddp=False) -> List:
        # 返回训练中所需使用的各种loader
        return [train_data.train_loader(
            batch_size = self.config['batch_size'],
            shuffle = True,
            num_workers = self.config['num_workers'],
            drop_last = False, ddp=ddp)]

    def current_epoch_trainloaders(self, nepoch) -> Tuple:
        # 当前epoch要使用的trainloader
        combine = False
        return self.trainloaders, combine

    def current_epoch_optimizers(self, nepoch) -> List:
        # 当前epoch要使用的优化器
        return self.optimizers

可能利用的数据

TODO: 1. 主要讲述清楚原始数据到数据集再到模型的训练/测试流程。可以考虑一个最全的数据(包括item user侧信息,kg,social graph信息等)。 2. 数据处理后的padding,num_items等个数的访问。 3. 以IRGAN等为例解释一些更复杂模型的设计。